Unverified Commit 96d999fb authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Kernel] Initial Machete W4A8 support + Refactors (#9855)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent c2170a5b
This diff is collapsed.
......@@ -20,10 +20,11 @@ if __name__ == "__main__":
args = parser.parse_args()
with open(args.filename, 'rb') as f:
data: List[TMeasurement] = pickle.load(f)
data = pickle.load(f)
raw_results: List[TMeasurement] = data["results"]
results = defaultdict(lambda: list())
for v in data:
for v in raw_results:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None:
KN = result.group(1)
......
......@@ -40,4 +40,10 @@ WEIGHT_SHAPES = {
([8192, 57344], 1),
([28672, 8192], 0),
],
"meta-llama/Llama-3.1-405b-hf": [
([16384, 18432], 1),
([16384, 16384], 0),
([16384, 106496], 1),
([53248, 16384], 0),
],
}
......@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
// is the layout f(x) = x
template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>)
if constexpr (std::is_same_v<Layout, void>) {
return true;
else {
} else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
......
......@@ -52,6 +52,7 @@
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock {
......
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm::c2x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
// it would technically work but no use case as data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
protected:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
}; // namespace vllm::c2x
\ No newline at end of file
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm::c3x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
}; // namespace vllm::c3x
\ No newline at end of file
......@@ -35,6 +35,35 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
}
}
VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
**DataTypeSize, # type: ignore
**{
VLLMDataType.u4b8: 4,
VLLMDataType.u8b128: 8,
}
}
VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
VLLMDataType.u4b8: "vllm::kU4B8",
VLLMDataType.u8b128: "vllm::kU8B128",
DataType.u4: "vllm::kU4",
DataType.u8: "vllm::kU8",
DataType.s4: "vllm::kS4",
DataType.s8: "vllm::kS8",
DataType.f16: "vllm::kFloat16",
DataType.bf16: "vllm::kBfloat16",
}
VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType.u8: "at::ScalarType::Byte",
DataType.s8: "at::ScalarType::Char",
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
DataType.s32: "at::ScalarType::Int",
DataType.f16: "at::ScalarType::Half",
DataType.bf16: "at::ScalarType::BFloat16",
DataType.f32: "at::ScalarType::Float",
}
VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
......
......@@ -3,6 +3,7 @@
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_type_utils.cuh"
// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
......@@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter {
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
CUTE_INVALID_CONTROL_PATH(
"InterleavedNumericArrayConverter not implemented\n");
if (cute::elect_one_sync()) {
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
printf(
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
nameof_v<T>, nameof_v<S>, N);
} else {
printf(
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
"implemented\n",
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
}
__brkpt();
}
return {};
}
......@@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter<
result_type operator()(source_type const& s) const { return convert(s); }
};
// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
// ....
template <typename RegConvert32bit, typename T, typename S, int N>
struct ArrayConverterPacked32Bit {
using result_type = Array<T, N>;
......@@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit {
using ScalarConverter = NumericConverter<T, S>;
template <typename PackedSrc>
CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
if constexpr (sizeof(PackedSrc) == 1) {
return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source));
return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 2) {
return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source));
return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 4) {
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
} else {
static_assert(sizeof(PackedSrc) == 4);
return reinterpret_cast<const uint32_t&>(source);
static_assert(sizeof(PackedSrc) == 8);
return reinterpret_cast<Array<uint32_t, 2> const&>(src);
}
}
......@@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit {
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
return RegConvert32bit::template convert<PackedResultType>(to_reg(source));
return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
}
friend class detail::VectorizedConverter;
......@@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit {
}
};
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
// into 2 32bit register.
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3, //
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7, //
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11, //
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
uint32_t src) {
cutlass::AlignedArray<uint32_t, 2> r;
// Determines if the value is in the top half of the LUT if set or
// (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move
// into bit position 0x4 of each nibble so when or'd with final_prmt_base it
// selects the correct candidate. When elements in final_prmt_base
// are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
// are < 0x4, the low candidate is selected (i.e. LUT[0:7])
uint32_t high_bit = (src & 0x88888888) >> 1;
// `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT
// (selects correct high or low candidate)
const uint32_t final_prmt_base = 0x32103210;
// Ignore the high bit when indexing into LUT, for each 4bit value
// we index into both the high and low candidates then use
// high_bit | final_prmt_base to select the correct candidate
uint32_t lut_idx = (src & 0x77777777);
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
(uint32_t(d) << 24);
};
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
uint32_t final_prmt_idx = final_prmt_base | high_bit;
// This uses a look up table to convert packed int4s to packed int8s,
// using the int4 value as the index to prmt. It first select both the
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
// select the correct candidate.
asm volatile(
"{\n"
" .reg .b32 low, high;\n"
" prmt.b32 low, %1, %2, %5;\n"
" prmt.b32 high, %3, %4, %5;\n"
" prmt.b32 %0, low, high, %6;\n"
"}\n"
: "=r"(r[ii])
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
"r"(final_prmt_idx));
}
return r;
};
// for Array<int8_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
0xFC, 0xFD, 0xFE, 0xFF, //
0x00, 0x01, 0x02, 0x03, //
0x04, 0x05, 0x06, 0x07>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::float_e4m3_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::float_e4m3_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<cutlass::float_e4m3_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
0xC8, 0xC4, 0xC0, 0xB8, //
0x00, 0x38, 0x40, 0x44, //
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
......@@ -148,7 +282,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
......@@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
......@@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
......@@ -417,7 +554,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
......@@ -469,7 +607,8 @@ struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
PackedResultType r;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
......@@ -513,7 +652,8 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src_reg = src_[0];
// Hold output BF16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
......@@ -603,7 +743,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
......@@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
......@@ -788,6 +930,61 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {
#endif
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<cutlass::half_t, N>;
struct RegConvert {
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <typename PackedResultType, int src_regs>
CUTLASS_DEVICE static PackedResultType convert(
Array<uint32_t, src_regs> src) {
// Hold output int8s in reg. We need 1 reg for every 4 elements
using RegArray = cutlass::AlignedArray<
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
RegArray r;
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
*reinterpret_cast<half2*>(&src[0]) =
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
if constexpr (src_regs > 1) {
*reinterpret_cast<half2*>(&src[1]) =
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
}
static_assert(PackedResultType::kElements <= 4);
uint32_t uint8s;
static constexpr uint32_t MASK_0246 = 0x6420;
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(uint8s)
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
"n"(MASK_0246));
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
return reinterpret_cast<PackedResultType&>(int8s);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
......
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
namespace cutlass {
template <typename T>
struct nameof {
static constexpr char const* value = "unknown";
};
template <typename T>
inline constexpr auto nameof_v = nameof<T>::value;
#define NAMEOF_TYPE(T) \
template <> \
struct nameof<T> { \
static constexpr char const* value = #T; \
};
NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)
NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)
NAMEOF_TYPE(vllm_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(vllm_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)
}; // namespace cutlass
\ No newline at end of file
......@@ -8,6 +8,10 @@
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
......@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
......@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
......@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
......@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
......@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
......@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
......@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
Epilogue>(
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
......@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm89_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
......@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
......@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
......@@ -21,7 +21,6 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
......@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
#endif
}
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
// it would technically work but no use case as data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
protected:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
......
......@@ -23,11 +23,12 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "common.hpp"
// clang-format on
using namespace cute;
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
......@@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
#endif
}
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
......@@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
if (bias) {
TORCH_CHECK(bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype());
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
c, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
b_scales);
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
c, a, b, a_scales, b_scales);
}
}
......@@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
......
This diff is collapsed.
......@@ -171,6 +171,10 @@ struct MacheteCollectiveMma {
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
Int<DispatchPolicy::Stages>{})));
using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy(
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
Int<DispatchPolicy::Stages>{})));
using SmemLayoutAtomARowMajor =
decltype(rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})),
......@@ -288,14 +292,7 @@ struct MacheteCollectiveMma {
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
"SmemLayoutAtomScale must evenly divide tile k shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutACopy = decltype(tile_to_shape(
SmemLayoutAtomARowMajor{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
// Tile along modes in a way that maximizes the TMA box size
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
......@@ -428,12 +425,12 @@ struct MacheteCollectiveMma {
// clang-format on
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset(
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy(
make_shape(int32_t(0), int32_t(0), int32_t(0)))));
using ATensor = decltype(make_tensor(
get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
shape(GmemLayoutA::TVbNbKL_to_offset(
shape(GmemLayoutA::TVbNbKL_to_offset_copy(
make_shape(int32_t(0), int32_t(0), int32_t(0)))),
PrepackedStrideA{}));
......@@ -450,8 +447,8 @@ struct MacheteCollectiveMma {
static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
return make_tma_copy<TmaElementA>(
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
shape(SmemLayoutA{}(_, _, cute::Int<0>{})),
GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}),
shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
}
......@@ -584,7 +581,7 @@ struct MacheteCollectiveMma {
typename Params::TMA_Scale tma_load_scale;
typename Params::TMA_Zero tma_load_zero;
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
tma_load_a = make_tma_copy_A(
make_logical_tensor(ptr_A, shape(layout), stride(layout)));
......@@ -722,7 +719,7 @@ struct MacheteCollectiveMma {
// (TILE_V,TILE_B,m,k,l)
auto make_gA_mkl = [&]() {
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout));
return local_tile(mA_mkl,
make_shape(size<0>(layout), PPBlocksPerTile_MK{}),
......
......@@ -21,6 +21,8 @@
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/torch_utils.hpp"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
......@@ -37,27 +39,42 @@ using namespace cute;
// W is quantized, in this situation or right-hand operand is quantized so
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename ScaleT, typename ZeroT,
class KernelSchedule, typename ScheduleConfig, bool with_C,
bool with_scales, bool with_zeropoints>
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
typename ScheduleConfig>
struct MacheteKernelTemplate {
static constexpr bool with_C = false; // not ever used
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
static constexpr bool with_group_zeropoints =
!std::is_same_v<GroupZeroT, void>;
static constexpr bool with_channel_scales =
!std::is_same_v<ChannelScaleT, void>;
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementZ = ZeroT;
using ElementS = ScaleT;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementAccumulator = AccumulatorT;
using ElementCompute = AccumulatorT; // For Epilogue
// Use dummy values when we don't have scales or zeropoints
using ElementZGroup =
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
using ElementSGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementConvertGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementSChannel =
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
using ElementSToken =
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
using BTypeTuple = cute::conditional_t<
with_scales,
cute::conditional_t<with_zeropoints,
cute::tuple<ElementB, ElementS, ElementZ>,
cute::tuple<ElementB, ElementS>>,
with_group_scales,
cute::conditional_t<with_group_zeropoints,
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
cute::tuple<ElementB, ElementSGroup>>,
ElementB>;
using LayoutA = cutlass::layout::RowMajor;
......@@ -71,8 +88,8 @@ struct MacheteKernelTemplate {
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZ = StrideS;
using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZGroup = StrideSGroup;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
......@@ -85,8 +102,8 @@ struct MacheteKernelTemplate {
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
LayoutA_Transpose, KernelSchedule>;
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
......@@ -103,12 +120,42 @@ struct MacheteKernelTemplate {
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
static_assert(
(!with_channel_scales && !with_token_scales) ||
((with_channel_scales && with_token_scales) &&
std::is_same_v<ElementSChannel, ElementSToken>),
"Currently token and channel scales (if present) must be the same type");
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
// Currently only supports float scales
using ChTokScalesEpilogue =
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
EpilogueDescriptor>;
static_assert((with_channel_scales || with_token_scales) ||
(std::is_same_v<ElementSChannel, float> &&
std::is_same_v<ElementSToken, float>),
"Currently token and channel scales (if present) must be float "
"(and if one is present the other must be too)");
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90AccFetch>;
using EVTCompute =
std::conditional_t<with_channel_scales || with_token_scales,
typename ChTokScalesEpilogue::EVTCompute,
StoreEpilogueCompute>;
// EVTCompute
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
EpilogueSchedule>::CollectiveOp;
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
EVTCompute>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
......@@ -131,26 +178,44 @@ struct MacheteKernelTemplate {
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
typename ShapeZ>
static Arguments create_arguments(
cudaStream_t stream,
ElementA const* A_ptr, // A is an MxK matrix
Layout<ShapeA, StrideA> const& layout_A,
ElementB const* B_ptr, // B is an KxN prepacked matrix
ElementD* D_ptr, // D is an MxN matrix
Layout<ShapeD, StrideD> const& layout_D,
ElementC const* C_ptr, // C is an MxN matrix
std::optional<Layout<ShapeC, StrideC>> const& layout_C,
ElementS const* S_ptr, // S is an scale_KxN matrix
std::optional<Layout<ShapeS, StrideS>> const& layout_S,
ElementZ const* Z_ptr, // Z is an scale_KxN matrix
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
ElementCompute alpha, ElementCompute beta,
std::optional<int> maybe_group_size) {
static_assert(!with_zeropoints || with_scales);
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
torch::Tensor const& A, // MxK matrix
torch::Tensor const& B, // KxN prepacked matrix
torch::Tensor& D, // MxN matrix
c10::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
c10::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
c10::optional<int64_t> maybe_group_size,
c10::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
c10::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
{
static_assert(!with_group_zeropoints || with_group_scales);
int M = A.size(0), N = B.size(1), K = A.size(1);
TORCH_CHECK(D.size(0) == M && D.size(1) == N);
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_S_group =
maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
auto layout_Z_group =
maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
auto unwrap = [](auto const& t) {
return t ? t->const_data_ptr() : nullptr;
};
auto A_ptr = static_cast<ElementA const*>(A.const_data_ptr());
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
auto D_ptr = static_cast<ElementD*>(D.mutable_data_ptr());
auto S_group_ptr =
static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
auto S_channel_ptr =
static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
auto S_token_ptr =
static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
......@@ -159,26 +224,28 @@ struct MacheteKernelTemplate {
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_C) {
TORCH_CHECK(C_ptr && layout_C);
if constexpr (with_group_scales) {
TORCH_CHECK(S_group_ptr && layout_S_group);
TORCH_CHECK((size<0>(*layout_S_group) == scale_k &&
size<1>(*layout_S_group) == N));
} else {
TORCH_CHECK(!C_ptr, "C not supported");
TORCH_CHECK(!S_group_ptr, "Scales not supported");
}
if constexpr (with_scales) {
TORCH_CHECK(S_ptr && layout_S);
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
if constexpr (with_group_zeropoints) {
TORCH_CHECK(Z_group_ptr && layout_Z_group);
TORCH_CHECK((size<0>(*layout_Z_group) == scale_k &&
size<1>(*layout_Z_group) == N));
TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!S_ptr, "Scales not supported");
TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported");
}
if constexpr (with_zeropoints) {
TORCH_CHECK(Z_ptr && layout_Z);
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
TORCH_CHECK(layout_S && *layout_Z == *layout_S,
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
if constexpr (with_channel_scales || with_token_scales) {
TORCH_CHECK(
(maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
(maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
}
// Transpose A and D
......@@ -186,24 +253,33 @@ struct MacheteKernelTemplate {
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
auto stride_Ct = stride_Dt;
if (layout_C) {
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
}
MainloopArguments mainloop_arguments{};
EpilogueArguments epilogue_arguments{
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
// {Accum, C, C_layout, D, D}
EpilogueArguments epilogue_arguments{};
if constexpr (with_channel_scales || with_token_scales) {
epilogue_arguments =
EpilogueArguments{ChTokScalesEpilogue::prepare_args(
*maybe_ch_scales, *maybe_tok_scales),
nullptr,
{},
D_ptr,
stride_Dt};
} else {
epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
}
if constexpr (with_scales && with_zeropoints) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_ptr, stride_S, group_size, Z_ptr};
} else if constexpr (with_scales) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
if constexpr (with_group_scales && with_group_zeropoints) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
} else if constexpr (with_group_scales) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
......
......@@ -5,73 +5,61 @@
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
#include "core/scalar_type.hpp"
namespace machete {
struct PyTorchArguments {
struct MMArgs {
torch::Tensor const& A;
torch::Tensor const& B;
c10::optional<torch::Tensor> const& scales;
c10::optional<torch::Tensor> const& zeros;
c10::optional<int64_t> group_size;
c10::optional<torch::Tensor> const& C;
c10::optional<double> alpha;
c10::optional<double> beta;
c10::optional<std::string> schedule;
vllm::ScalarType const& b_type;
c10::optional<at::ScalarType> const& maybe_out_type;
c10::optional<torch::Tensor> const& maybe_group_scales;
c10::optional<torch::Tensor> const& maybe_group_zeros;
c10::optional<int64_t> maybe_group_size;
c10::optional<torch::Tensor> const& maybe_channel_scales;
c10::optional<torch::Tensor> const& maybe_token_scales;
c10::optional<std::string> maybe_schedule;
};
struct SupportedSchedulesArgs {
at::ScalarType a_type;
vllm::ScalarType b_type;
c10::optional<at::ScalarType> maybe_group_scales_type;
c10::optional<at::ScalarType> maybe_group_zeros_type;
c10::optional<at::ScalarType> maybe_channel_scales_type;
c10::optional<at::ScalarType> maybe_token_scales_type;
c10::optional<at::ScalarType> maybe_out_type;
};
torch::Tensor mm_dispatch(MMArgs args);
std::vector<std::string> supported_schedules_dispatch(
SupportedSchedulesArgs args);
template <typename MacheteKernel>
torch::Tensor run_impl(PyTorchArguments args) {
torch::Tensor run_impl(MMArgs args) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
auto device = args.A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
using EleA = typename MacheteKernel::ElementA;
using EleB = typename MacheteKernel::ElementB;
using EleC = typename MacheteKernel::ElementC;
using EleD = typename MacheteKernel::ElementD;
using EleScale = typename MacheteKernel::ElementS;
using EleZero = typename MacheteKernel::ElementZ;
using StrideA = typename MacheteKernel::StrideA;
using StrideC = typename MacheteKernel::StrideC;
using StrideD = typename MacheteKernel::StrideD;
using StrideS = typename MacheteKernel::StrideS;
using StrideZ = typename MacheteKernel::StrideZ;
int M = args.A.size(0);
int N = args.B.size(1);
int K = args.A.size(1);
// Allocate output
torch::Tensor D =
torch::empty({M, N}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<EleD>)
.device(device));
auto const &A = args.A, &B = args.B;
auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
auto S_ptr =
static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
auto Z_ptr =
static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
torch::Tensor D = torch::empty(
{M, N},
torch::TensorOptions()
.dtype(equivalent_scalar_type_v<typename MacheteKernel::ElementD>)
.device(device));
auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size);
stream, //
args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
args.maybe_group_size, args.maybe_channel_scales,
args.maybe_token_scales);
TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");
......@@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) {
return D;
};
template <typename ElementA, typename ElementB, typename ElementD = ElementA,
typename AccumulatorT = float, typename ScaleT = ElementA,
typename ZeroT = ElementA>
struct GemmDispatcher {
static torch::Tensor dispatch(PyTorchArguments args);
static std::vector<std::string> supported_schedules();
};
}; // namespace machete
\ No newline at end of file
......@@ -6,31 +6,49 @@
namespace machete {
template <typename TileShapeNKL, typename ElementB, typename BInTensor,
typename BTiledOutTensor>
static __global__ void prepack_B_kernel(BInTensor B_in,
BTiledOutTensor B_tiled_out) {
auto tB_in = local_tile(B_in, TileShapeNKL{},
make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
auto tB_out = B_tiled_out(make_coord(_, _),
make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
template <int threads, typename PrepackedLayoutB, typename BInTensor,
typename ElementB>
static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
auto constexpr block_size =
Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
auto constexpr eles_per_thread = Int<block_size / threads>{};
static_assert(block_size % threads == 0,
"block_size must be divisible by the number of threads");
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
Layout<Shape<_4, _32>, Stride<_32, _1>>{},
Layout<Shape<_1, _2>>{});
// Which pre-packed are we responsible for
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
auto tB_in = local_tile(
B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
blk_coord);
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
// Find the start offset in the output for this pre-packed block
auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
Tensor thr_tile_S = thr_copy.partition_S(tB_in);
Tensor thr_tile_D = thr_copy.partition_D(tB_out);
// Tensor representing a 1:1 mapping to the output space in 1D
auto tB_out_linear =
make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
make_layout(make_shape(block_size)));
// Mapping from output space (1D) to input space
auto tB_in_linear = make_tensor(
tB_in.data(),
tB_in.layout()
.compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
.with_shape(make_shape(block_size)));
// Tile for this specific thread (could have used a TiledCopy but these work
// best with 2d layouts, this is a simple 1d layout so local_tile is enough,
// we are also not that concerned with performance for this kernel)
auto thr_tB_in_linear =
local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
auto thr_tB_out_linear =
local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
// Construct a register-backed Tensor with the same shape as each thread's
// partition
auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(tiled_copy, thr_tile_S, fragment);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
copy(thr_tB_in_linear, fragment);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
}
template <typename PrepackedLayoutB, typename InLayout>
......@@ -44,18 +62,15 @@ static void prepack_B_template(
TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
auto L_tiles = size<2>(B_layout);
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
auto B_tiled_out =
make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
prepack_B_kernel<128, PrepackedLayoutB>
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
}
}; // namespace machete
\ No newline at end of file
......@@ -2,9 +2,17 @@
#include "machete_prepack_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
#include "core/scalar_type.hpp"
namespace machete {
struct PrepackBArgs {
torch::Tensor const& B;
at::ScalarType a_type;
vllm::ScalarType b_type;
c10::optional<at::ScalarType> maybe_group_scales_type;
};
template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
......@@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
return D;
};
template <typename ElementA, typename ElementB, typename ElementD,
typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
typename ZeroT = cutlass::half_t>
struct PrepackBDispatcher {
static torch::Tensor dispatch(torch::Tensor B);
};
torch::Tensor prepack_B_dispatch(PrepackBArgs args);
}; // namespace machete
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment