Unverified Commit 96d999fb authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Kernel] Initial Machete W4A8 support + Refactors (#9855)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent c2170a5b
This diff is collapsed.
...@@ -20,10 +20,11 @@ if __name__ == "__main__": ...@@ -20,10 +20,11 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
with open(args.filename, 'rb') as f: with open(args.filename, 'rb') as f:
data: List[TMeasurement] = pickle.load(f) data = pickle.load(f)
raw_results: List[TMeasurement] = data["results"]
results = defaultdict(lambda: list()) results = defaultdict(lambda: list())
for v in data: for v in raw_results:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None: if result is not None:
KN = result.group(1) KN = result.group(1)
......
...@@ -40,4 +40,10 @@ WEIGHT_SHAPES = { ...@@ -40,4 +40,10 @@ WEIGHT_SHAPES = {
([8192, 57344], 1), ([8192, 57344], 1),
([28672, 8192], 0), ([28672, 8192], 0),
], ],
"meta-llama/Llama-3.1-405b-hf": [
([16384, 18432], 1),
([16384, 16384], 0),
([16384, 106496], 1),
([53248, 16384], 0),
],
} }
...@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { ...@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
// is the layout f(x) = x // is the layout f(x) = x
template <typename Layout> template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>) if constexpr (std::is_same_v<Layout, void>) {
return true; return true;
else { } else {
constexpr auto coalesced_layout = coalesce(Layout{}); constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 && if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) { stride<0>(coalesced_layout) == 1) {
......
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
// clang-format off // clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock { namespace cutlass::epilogue::threadblock {
......
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm::c2x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
// it would technically work but no use case as data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
return Arguments{data_ptr};
}
};
/*
 * Epilogue for a quantized GEMM in the spirit of torch._scaled_mm.
 *
 * A and B are either both int8 or both fp8_e4m3, and must be symmetrically
 * quantized (zero point == 0). A may be quantized per-tensor or per-row, B
 * per-tensor or per-column, in any combination.
 *
 * The resulting operation is D = (a_scales * A) (b_scales * B), with the
 * scales applied elementwise using numpy-style broadcasting.
 *
 * ScaleA and ScaleB are the epilogue loads that apply the scales of the A
 * and B operands respectively; each may be per-tensor or per row/column.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  // Inner node: ScaleB * Accum.
  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB, Accum>;

  // Outer node: ScaleA * (inner node), producing ElementD.
  using ComputeScaleA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleA, ScaleA,
                                              EVTComputeScaleB>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the EVTCompute arguments from the two scale tensors.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto scale_b = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments inner_args{scale_b};

    auto scale_a = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    return ArgumentType{scale_a, inner_args};
  }
};
/*
 * Same operation as ScaledEpilogue, with a bias added on top.
 *
 * The bias can double as the carrier of the per-tensor azp correction: the
 * activation zero point (azp) is used to compute a correction term, which is
 * then folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  // Inner node: ScaleB * Accum.
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  // Outer node: multiply_add over (ScaleA, inner node, Bias), producing
  // ElementD.
  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the EVTCompute arguments from the scale and bias tensors.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto scale_b = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTCompute0::Arguments inner_args{scale_b};

    auto scale_a = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto bias_load = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    return ArgumentType{scale_a, inner_args, bias_load};
  }
};
/*
 * Epilogue with direct support for per-tensor azp in int32 form.
 *
 * Unlike the per-token epilogue, this one only takes an azp_adj term, which
 * must already have been multiplied by the scalar azp. azp_adj is a 1D
 * tensor of shape (1,n), computed as azp * J @ B.
 *
 * Bias is supported as well and remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  // Optional bias: an absent tensor is passed as nullptr.
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // The complete AZP term, azp * J @ B, shape (1,n).
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // float(accum - azp_adj); both operands are int32_t.
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  // ScaleB * (azp-corrected accumulator).
  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  // multiply_add over (ScaleA, scaled accumulator, Bias), producing ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the EVTCompute arguments, building the tree from its innermost
  // node outwards. `bias` may be empty.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto adj_load =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
    typename EVTComputeAzp::Arguments azp_node{{}, adj_load};

    auto scale_b = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments scale_b_node{scale_b, azp_node};

    auto scale_a = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto bias_load = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    return ArgumentType{scale_a, scale_b_node, bias_load};
  }
};
/*
 * Epilogue with per-token azp support. The correction term is computed and
 * applied as a rank-1 update: materializing it would take O(m*n) space,
 * whereas this formulation only needs O(m+n).
 *
 * azp is a 1D tensor of shape (m,1) holding the unscaled zero point of each
 * row of A. azp_adj is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * Bias is supported as well and remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  // Optional bias: an absent tensor is passed as nullptr.
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1).
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // AZP adjustment term, J @ B, shape (1,n).
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // azp * azp_adj.
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // float(accum - azp*azp_adj); all operands are int32_t.
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  // ScaleB * (azp-corrected accumulator).
  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  // multiply_add over (ScaleA, scaled accumulator, Bias), producing ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the EVTCompute arguments, building the tree from its innermost
  // node outwards. `bias` may be empty.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto azp_load = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto adj_load = SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
    typename EVTComputeAzp::Arguments azp_node{azp_load, adj_load};

    typename EVTComputeAcc::Arguments acc_node{{}, azp_node};

    auto scale_b = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments scale_b_node{scale_b, acc_node};

    auto scale_a = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto bias_load = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    return ArgumentType{scale_a, scale_b_node, bias_load};
  }
};
}; // namespace vllm::c2x
\ No newline at end of file
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm::c3x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
};
/*
 * Epilogue for a quantized GEMM in the spirit of torch._scaled_mm.
 *
 * A and B are either both int8 or both fp8_e4m3, and must be symmetrically
 * quantized (zero point == 0). A may be quantized per-tensor or per-row, B
 * per-tensor or per-column, in any combination.
 *
 * The resulting operation is D = (a_scales * A) (b_scales * B), with the
 * scales applied elementwise using numpy-style broadcasting.
 *
 * ScaleA and ScaleB are the epilogue loads that apply the scales of the A
 * and B operands respectively; each may be per-tensor or per row/column.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  // Inner node: ScaleB * Accum.
  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, Accum>;

  // Outer node: ScaleA * (inner node), producing ElementD.
  using ComputeScaleA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleA, ScaleA,
                                         EVTComputeScaleB>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the EVTCompute arguments from the two scale tensors.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto scale_b = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments inner_args{scale_b};

    auto scale_a = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    return ArgumentType{scale_a, inner_args};
  }
};
/*
 * Same operation as ScaledEpilogue, with a bias added on top.
 *
 * The bias can double as the carrier of the per-tensor azp correction: the
 * activation zero point (azp) is used to compute a correction term, which is
 * then folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  // Inner node: ScaleB * Accum.
  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, Accum>;

  // Outer node: multiply_add over (ScaleA, inner node, Bias), producing
  // ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the EVTCompute arguments from the scale and bias tensors.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto scale_b = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments inner_args{scale_b};

    auto scale_a = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto bias_load = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    return ArgumentType{scale_a, inner_args, bias_load};
  }
};
/*
 * Epilogue with direct support for per-tensor azp in int32 form.
 *
 * Unlike the per-token epilogue, this one only takes an azp_adj term, which
 * must already have been multiplied by the scalar azp. azp_adj is a 1D
 * tensor of shape (1,n), computed as azp * J @ B.
 *
 * Bias is supported as well and remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  // Optional bias: nullptr support is enabled so the tensor may be absent.
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // The complete AZP term, azp * J @ B, shape (1,n).
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // float(accum - azp_adj); both operands are int32_t.
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;

  // ScaleB * (azp-corrected accumulator).
  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;

  // multiply_add over (ScaleA, scaled accumulator, Bias), producing ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the EVTCompute arguments, building the tree from its innermost
  // node outwards. `bias` may be empty.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto adj_load =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
    typename EVTComputeAzp::Arguments azp_node{{}, adj_load};

    auto scale_b = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments scale_b_node{scale_b, azp_node};

    auto scale_a = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto bias_load = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    return ArgumentType{scale_a, scale_b_node, bias_load};
  }
};
/*
 * Epilogue with per-token azp support. The correction term is computed and
 * applied as a rank-1 update: materializing it would take O(m*n) space,
 * whereas this formulation only needs O(m+n).
 *
 * azp is a 1D tensor of shape (m,1) holding the unscaled zero point of each
 * row of A. azp_adj is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * Bias is supported as well and remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  // Optional bias: nullptr support is enabled so the tensor may be absent.
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1).
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // AZP adjustment term, J @ B, shape (1,n).
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // azp * azp_adj.
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;

  // float(accum - azp*azp_adj); all operands are int32_t.
  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeAcc =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;

  // ScaleB * (azp-corrected accumulator).
  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;

  // multiply_add over (ScaleA, scaled accumulator, Bias), producing ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the EVTCompute arguments, building the tree from its innermost
  // node outwards. `bias` may be empty.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto azp_load = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto adj_load = SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
    typename EVTComputeAzp::Arguments azp_node{azp_load, adj_load};

    typename EVTComputeAcc::Arguments acc_node{{}, azp_node};

    auto scale_b = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    typename EVTComputeScaleB::Arguments scale_b_node{scale_b, acc_node};

    auto scale_a = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto bias_load = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    return ArgumentType{scale_a, scale_b_node, bias_load};
  }
};
}; // namespace vllm::c3x
\ No newline at end of file
...@@ -35,6 +35,35 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { ...@@ -35,6 +35,35 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
} }
} }
VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
**DataTypeSize, # type: ignore
**{
VLLMDataType.u4b8: 4,
VLLMDataType.u8b128: 8,
}
}
VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
VLLMDataType.u4b8: "vllm::kU4B8",
VLLMDataType.u8b128: "vllm::kU8B128",
DataType.u4: "vllm::kU4",
DataType.u8: "vllm::kU8",
DataType.s4: "vllm::kS4",
DataType.s8: "vllm::kS8",
DataType.f16: "vllm::kFloat16",
DataType.bf16: "vllm::kBfloat16",
}
VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType.u8: "at::ScalarType::Byte",
DataType.s8: "at::ScalarType::Char",
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
DataType.s32: "at::ScalarType::Int",
DataType.f16: "at::ScalarType::Half",
DataType.bf16: "at::ScalarType::BFloat16",
DataType.f32: "at::ScalarType::Float",
}
VLLMKernelScheduleTag: Dict[Union[ VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = { MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore **KernelScheduleTag, # type: ignore
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "cutlass/numeric_conversion.h" #include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh" #include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh" #include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_type_utils.cuh"
// this file extends: // this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h // https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
...@@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter { ...@@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter {
CUTLASS_DEVICE CUTLASS_DEVICE
static result_type convert(source_type const& source) { static result_type convert(source_type const& source) {
CUTE_INVALID_CONTROL_PATH( if (cute::elect_one_sync()) {
"InterleavedNumericArrayConverter not implemented\n"); if constexpr (std::is_same_v<IlvBlkLayout, void>) {
printf(
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
nameof_v<T>, nameof_v<S>, N);
} else {
printf(
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
"implemented\n",
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
}
__brkpt();
}
return {}; return {};
} }
...@@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter< ...@@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter<
result_type operator()(source_type const& s) const { return convert(s); } result_type operator()(source_type const& s) const { return convert(s); }
}; };
// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
// ....
template <typename RegConvert32bit, typename T, typename S, int N> template <typename RegConvert32bit, typename T, typename S, int N>
struct ArrayConverterPacked32Bit { struct ArrayConverterPacked32Bit {
using result_type = Array<T, N>; using result_type = Array<T, N>;
...@@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit { ...@@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit {
using ScalarConverter = NumericConverter<T, S>; using ScalarConverter = NumericConverter<T, S>;
template <typename PackedSrc> template <typename PackedSrc>
CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
if constexpr (sizeof(PackedSrc) == 1) { if constexpr (sizeof(PackedSrc) == 1) {
return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source)); return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 2) { } else if constexpr (sizeof(PackedSrc) == 2) {
return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source)); return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 4) {
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
} else { } else {
static_assert(sizeof(PackedSrc) == 4); static_assert(sizeof(PackedSrc) == 8);
return reinterpret_cast<const uint32_t&>(source); return reinterpret_cast<Array<uint32_t, 2> const&>(src);
} }
} }
...@@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit { ...@@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit {
static_assert(std::is_same_v<typename PackedSrcType::Element, S>); static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
static_assert(std::is_same_v<typename PackedResultType::Element, T>); static_assert(std::is_same_v<typename PackedResultType::Element, T>);
return RegConvert32bit::template convert<PackedResultType>(to_reg(source)); return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
} }
friend class detail::VectorizedConverter; friend class detail::VectorizedConverter;
...@@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit { ...@@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit {
} }
}; };
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
// into 2 32bit registers, using a 16-entry byte lookup table.
//
// Template parameters:
//   LUT0..LUT15 - LUTi is the output byte produced for an input nibble whose
//                 (unsigned) value is i.
// Parameters:
//   src - 8 packed 4bit values; nibble k occupies bits [4k, 4k+3].
// Returns:
//   r[0] holds the bytes for nibbles 0-3 and r[1] the bytes for nibbles 4-7;
//   within a register, byte k corresponds to the k-th nibble of that half.
//
// The lookup is done with `prmt`, which can only choose between 8 candidate
// bytes at a time, so each nibble is looked up in both halves of the table and
// a third `prmt` keeps the half indicated by the nibble's top bit.
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3, //
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7, //
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11, //
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
uint32_t src) {
cutlass::AlignedArray<uint32_t, 2> r;
// The top bit of each nibble says which half of the LUT the value lives in:
// set -> top half (i.e. LUT[8:15]), clear -> bottom half (i.e. LUT[0:7]).
// Move it into bit position 0x4 of each nibble so that, when OR'd with
// final_prmt_base, it selects the correct candidate: selector elements >= 0x4
// pick the high candidate (i.e. LUT[8:15]), selector elements < 0x4 pick the
// low candidate (i.e. LUT[0:7]).
uint32_t high_bit = (src & 0x88888888) >> 1;
// `high_bit` is OR'd with 0x32103210 to find the correct value in the LUT
// (selects correct high or low candidate). The base 0x3210 keeps byte k of
// the output tied to byte k of the low/high candidate registers.
const uint32_t final_prmt_base = 0x32103210;
// Ignore the high bit when indexing into LUT, for each 4bit value
// we index into both the high and low candidates then use
// high_bit | final_prmt_base to select the correct candidate.
// Clearing bit 3 also matters for `prmt` itself: per the PTX ISA, the msb of
// each 4bit selector requests sign replication rather than a plain byte copy.
uint32_t lut_idx = (src & 0x77777777);
// Packs four bytes into one 32bit word, `a` in the least significant byte.
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
(uint32_t(d) << 24);
};
// The table split into four words: LOW_* hold LUT[0:7], HIGH_* hold LUT[8:15].
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
// Each iteration consumes 16 bits (4 nibbles) of the selectors and produces
// one output register; the selectors are shifted down for the next pass.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
uint32_t final_prmt_idx = final_prmt_base | high_bit;
// This uses a look up table to convert packed int4s to packed int8s,
// using the int4 value as the index to prmt. It first selects both the
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
// select the correct candidate.
// Operands: %1/%2 = low table words, %3/%4 = high table words,
// %5 = per-nibble LUT index, %6 = final low/high selector.
asm volatile(
"{\n"
" .reg .b32 low, high;\n"
" prmt.b32 low, %1, %2, %5;\n"
" prmt.b32 high, %3, %4, %5;\n"
" prmt.b32 %0, low, high, %6;\n"
"}\n"
: "=r"(r[ii])
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
"r"(final_prmt_idx));
}
return r;
// NOTE(review): the ';' after the closing brace below is redundant.
};
// for Array<int8_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
0xFC, 0xFD, 0xFE, 0xFF, //
0x00, 0x01, 0x02, 0x03, //
0x04, 0x05, 0x06, 0x07>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::float_e4m3_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::float_e4m3_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<cutlass::float_e4m3_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
0xC8, 0xC4, 0xC0, 0xB8, //
0x00, 0x38, 0x40, 0x44, //
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N> // for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N> template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> { struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
...@@ -148,7 +282,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> { ...@@ -148,7 +282,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>, ...@@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>, ...@@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -417,7 +554,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> { ...@@ -417,7 +554,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
// Hold output FP16s in reg. We need 1 reg for every 2 elements // Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
...@@ -469,7 +607,8 @@ struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> { ...@@ -469,7 +607,8 @@ struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
PackedResultType r; PackedResultType r;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
...@@ -513,7 +652,8 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> { ...@@ -513,7 +652,8 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src_reg = src_[0];
// Hold output BF16s in reg. We need 1 reg for every 2 elements // Hold output BF16s in reg. We need 1 reg for every 2 elements
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
...@@ -603,7 +743,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>, ...@@ -603,7 +743,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>, ...@@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -788,6 +930,61 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> { ...@@ -788,6 +930,61 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {
#endif #endif
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<cutlass::half_t, N>;
struct RegConvert {
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <typename PackedResultType, int src_regs>
CUTLASS_DEVICE static PackedResultType convert(
Array<uint32_t, src_regs> src) {
// Hold output int8s in reg. We need 1 reg for every 4 elements
using RegArray = cutlass::AlignedArray<
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
RegArray r;
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
*reinterpret_cast<half2*>(&src[0]) =
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
if constexpr (src_regs > 1) {
*reinterpret_cast<half2*>(&src[1]) =
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
}
static_assert(PackedResultType::kElements <= 4);
uint32_t uint8s;
static constexpr uint32_t MASK_0246 = 0x6420;
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(uint8s)
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
"n"(MASK_0246));
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
return reinterpret_cast<PackedResultType&>(int8s);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass } // namespace cutlass
......
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
namespace cutlass {
// Compile-time lookup of a printable name for a type. The primary template is
// the fallback for any type that has no dedicated specialization.
template <class T>
struct nameof {
  static constexpr const char* value{"unknown"};
};

// Shorthand for nameof<T>::value.
template <class T>
inline constexpr const char* nameof_v = nameof<T>::value;
// Registers a printable name for type T by specializing nameof<T>.
// The name is the macro argument stringized exactly as written, so
// NAMEOF_TYPE(half_t) yields "half_t" (no namespace qualification).
// Must be expanded at namespace scope, in the namespace that declares nameof.
// Two spellings that alias the same type would produce duplicate
// specializations, so each underlying type may be registered only once.
#define NAMEOF_TYPE(T) \
template <> \
struct nameof<T> { \
static constexpr char const* value = #T; \
};
// Types with a registered name; anything else reports "unknown".
NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)
NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)
NAMEOF_TYPE(vllm_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(vllm_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)
}; // namespace cutlass
\ No newline at end of file
...@@ -8,6 +8,10 @@ ...@@ -8,6 +8,10 @@
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm;
/* /*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper). NVIDIA GPUs with SM versions prior to sm90 (Hopper).
...@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
...@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, ...@@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} }
...@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, ...@@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else { } else {
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>( return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias); out, a, b, a_scales, b_scales, azp_adj, bias);
} }
} }
...@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
...@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, ...@@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} }
...@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, ...@@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else { } else {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>( return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias); out, a, b, a_scales, b_scales, azp_adj, bias);
} }
} }
...@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>( Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
assert(out.dtype() == torch::kFloat16); assert(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} else { } else {
...@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, ...@@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm89_fp8_dispatch< return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>( cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>( cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} }
} }
...@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, ...@@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} }
...@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, ...@@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else { } else {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>( return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias); out, a, b, a_scales, b_scales, azp_adj, bias);
} }
} }
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp" #include "common.hpp"
// clang-format on // clang-format on
...@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel { ...@@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel {
#endif #endif
} }
}; };
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
// it would technically work but no use case as data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
return Arguments{data_ptr};
}
};
/*
 Epilogue for a quantized GEMM in the spirit of torch._scaled_mm.

 A and B are either both int8 or both fp8_e4m3, symmetrically quantized
 (zero point == 0). A may be scaled per-tensor or per-row, B per-tensor or
 per-column, in any combination.

 The result is D = (a_scales * A) (b_scales * B), with the scales applied
 elementwise using numpy-style broadcasting.

 ScaleA and ScaleB are the load nodes that supply the scale for the A and B
 operand respectively; each one handles both the per-tensor and the
 per-row/column case.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  // Inner node: ScaleB * accum, kept in float.
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  // Root node: ScaleA * (inner node), converted to ElementD.
  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the argument tree for EVTCompute from the two scale tensors.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto scale_b_load =
        SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto scale_a_load =
        SUPER::template args_from_tensor<ScaleA, float>(a_scales);

    // The accumulator node takes no arguments, so the inner node only needs
    // the ScaleB load.
    typename EVTCompute0::Arguments inner_node{scale_b_load};
    return ArgumentType{scale_a_load, inner_node};
  }
};
/*
 * Same computation as ScaledEpilogue, followed by the addition of a bias:
 *   D = ScaleA * (ScaleB * accum) + Bias
 *
 * The bias slot can also carry the correction term of the per-tensor azp
 * case, where the activation zero point (azp) is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  // Inner node: ScaleB * accum, kept in float.
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  // Root node: ScaleA * (inner node) + Bias, converted to ElementD.
  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the argument tree for EVTCompute from the scale and bias
  // tensors. `bias` is read as ElementD.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto bias_load = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto scale_b_load =
        SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto scale_a_load =
        SUPER::template args_from_tensor<ScaleA, float>(a_scales);

    // The accumulator node takes no arguments, so the inner node only needs
    // the ScaleB load.
    typename EVTCompute0::Arguments inner_node{scale_b_load};
    return ArgumentType{scale_a_load, inner_node, bias_load};
  }
};
/*
 * Epilogue with direct support for a per-tensor azp given in int32 form:
 *   D = ScaleA * (ScaleB * float(accum - azp_adj)) + Bias
 *
 * Unlike the per-token variant, only a single azp_adj term is taken, and it
 * must already include the scalar azp. azp_adj is a 1D tensor of shape (1,n),
 * computed as azp * J @ B.
 *
 * A bias is supported as well and remains per-channel; it may be omitted.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  // Optional bias: a missing tensor is replaced by zero.
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // The complete AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Innermost node: float(accum - azp_adj); both inputs are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  // Middle node: ScaleB * (innermost node), kept in float.
  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  // Root node: ScaleA * (middle node) + Bias, converted to ElementD.
  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  // Assembles the argument tree for EVTCompute. `azp_adj` is read as int32_t
  // and `bias`, when present, as ElementD.
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    // Leaf loads, one per input tensor.
    auto adj_load =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
    auto bias_load = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto scale_b_load =
        SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto scale_a_load =
        SUPER::template args_from_tensor<ScaleA, float>(a_scales);

    // Build the tree from the inside out. The leading {} is the (empty)
    // argument slot of the accumulator node.
    typename EVTComputeAzp::Arguments azp_node{{}, adj_load};
    typename EVTComputeScaleB::Arguments scale_b_node{scale_b_load, azp_node};
    return ArgumentType{scale_a_load, scale_b_node, bias_load};
  }
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
// Shorthand for the base class, which supplies the accumulator node, the
// load descriptors and the args_from_tensor helpers used below.
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
// Float scales for the A and B operands (column-or-scalar and row-or-scalar
// loads respectively).
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
// Bias in the output element type. NOTE(review): RowOrZeroLoad is defined
// outside this block; presumably it substitutes zero when no bias tensor is
// supplied - confirm against the base class.
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: ComputeAzp applied to (Azp, AzpAdj); the product of a column
// load and a row load is the rank-1 correction described above.
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: ComputeAcc applied to (Accum, EVTComputeAzp), in that order.
using EVTComputeAcc =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
// Multiply the zero-point-corrected accumulator by the B scale, in float.
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: ComputeScaleB applied to (ScaleB, EVTComputeAcc).
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAcc>;
// Final multiply_add step, computed in float and emitted as ElementD.
// Assuming cutlass::multiply_add(a, b, c) evaluates a * b + c (TODO confirm),
// the whole tree yields
//   scale_a * (scale_b * (accum - azp * azp_adj)) + bias.
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
// Root of the epilogue visitor tree: ComputeScaleBiasA applied to
// (ScaleA, EVTComputeScaleB, Bias).
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
// Builds the nested argument structure for EVTCompute from the runtime
// tensors.
//   a_scales: float scales for A, loaded through ScaleA.
//   b_scales: float scales for B, loaded through ScaleB.
//   azp_adj:  int32 adjustment term (J @ B), loaded through AzpAdj.
//   azp:      int32 per-token zero points, loaded through Azp.
//   bias:     optional bias in ElementD, loaded through Bias.
// Each brace-initializer below lists the children's arguments in the same
// order as the children appear in the corresponding Sm80EVT alias, so the
// ordering here must be kept in sync with the aliases above.
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
// Accum is the first child of EVTComputeAcc and is given empty arguments.
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
template <typename Arch, template <typename> typename ArchGuard, template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape, template <typename, typename> typename Epilogue_, typename TileShape,
......
...@@ -23,11 +23,12 @@ ...@@ -23,11 +23,12 @@
#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp"
#include "broadcast_load_epilogue_c3x.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "common.hpp" #include "common.hpp"
// clang-format on // clang-format on
using namespace cute; using namespace cute;
using namespace vllm;
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
...@@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel { ...@@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel {
#endif #endif
} }
}; };
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
// Shorthand for the base class that provides the load descriptors and the
// args_from_tensor helpers.
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
// Float scales: column-or-scalar load for A, row-or-scalar load for B.
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
// First step: multiply in float, float result.
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: Compute0 applied to (ScaleB, Accum), i.e. scale_b * accum.
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
// Second step: multiply in float, result emitted as ElementD.
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
// Root of the epilogue visitor tree: Compute1 applied to
// (ScaleA, EVTCompute0), i.e. scale_a * (scale_b * accum).
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
// Builds the nested argument structure for EVTCompute.
//   a_scales: float scales for A; one element selects the scalar path.
//   b_scales: float scales for B; one element selects the scalar path.
// The initializer order mirrors the child order of each Sm90EVT alias above.
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
// Only ScaleB's arguments are supplied here. NOTE(review): the remaining
// members of EVTCompute0::Arguments (for Accum) are presumably left to
// default initialization - confirm against Sm90EVT::Arguments.
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
// Shorthand for the base class that provides the load descriptors and the
// args_from_tensor helpers.
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
// Float scales: column-or-scalar load for A, row-or-scalar load for B.
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
// Bias as a row load in the output element type. EnableNullPtr keeps its
// default (false), so a bias tensor is mandatory for this epilogue.
using Bias = typename SUPER::template RowLoad<ElementD>;
// First step: multiply in float, float result.
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: Compute0 applied to (ScaleB, Accum), i.e. scale_b * accum.
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
// Second step: multiply_add in float, result emitted as ElementD.
// Assuming cutlass::multiply_add(a, b, c) evaluates a * b + c (TODO confirm),
// the tree yields scale_a * (scale_b * accum) + bias.
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
// Root of the epilogue visitor tree: Compute1 applied to
// (ScaleA, EVTCompute0, Bias).
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
// Builds the nested argument structure for EVTCompute.
//   a_scales: float scales for A; one element selects the scalar path.
//   b_scales: float scales for B; one element selects the scalar path.
//   bias:     bias tensor whose data is read as ElementD.
// The initializer order mirrors the child order of each Sm90EVT alias above.
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
// Only ScaleB's arguments are supplied here. NOTE(review): the remaining
// members of EVTCompute0::Arguments (for Accum) are presumably left to
// default initialization - confirm against Sm90EVT::Arguments.
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
// Shorthand for the base class that provides the load descriptors and the
// args_from_tensor helpers.
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
// Float scales: column-or-scalar load for A, row-or-scalar load for B.
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
// Bias as a row load with EnableNullPtr = true, so that the optional-tensor
// overload of args_from_tensor may pass a nullptr when no bias is given.
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: ComputeAzp applied to (Accum, AzpWithAdj), in that order.
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
// Multiply the zero-point-corrected accumulator by the B scale, in float.
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: ComputeScaleB applied to (ScaleB, EVTComputeAzp).
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
// Final multiply_add step, computed in float and emitted as ElementD.
// Assuming cutlass::multiply_add(a, b, c) evaluates a * b + c (TODO confirm),
// the whole tree yields scale_a * (scale_b * (accum - azp_adj)) + bias.
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
// Root of the epilogue visitor tree: ComputeScaleBiasA applied to
// (ScaleA, EVTComputeScaleB, Bias).
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
// Builds the nested argument structure for EVTCompute.
//   a_scales: float scales for A; one element selects the scalar path.
//   b_scales: float scales for B; one element selects the scalar path.
//   azp_adj:  int32 zero-point term, already multiplied by the scalar azp.
//   bias:     optional bias in ElementD; a nullptr is passed when absent.
// Each brace-initializer lists the children's arguments in the same order as
// the children appear in the corresponding Sm90EVT alias above.
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
// Accum is the first child of EVTComputeAzp and is given empty arguments.
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
// Shorthand for the base class that provides the load descriptors and the
// args_from_tensor helpers.
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
// Float scales: column-or-scalar load for A, row-or-scalar load for B.
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
// Bias as a row load with EnableNullPtr = true, so that the optional-tensor
// overload of args_from_tensor may pass a nullptr when no bias is given.
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: ComputeAzp applied to (Azp, AzpAdj); the product of a column
// load and a row load is the rank-1 correction described above.
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: ComputeAcc applied to (Accum, EVTComputeAzp), in that order.
using EVTComputeAcc =
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
// Multiply the zero-point-corrected accumulator by the B scale, in float.
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
// Tree node: ComputeScaleB applied to (ScaleB, EVTComputeAcc).
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
// Final multiply_add step, computed in float and emitted as ElementD.
// Assuming cutlass::multiply_add(a, b, c) evaluates a * b + c (TODO confirm),
// the whole tree yields
//   scale_a * (scale_b * (accum - azp * azp_adj)) + bias.
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
// Root of the epilogue visitor tree: ComputeScaleBiasA applied to
// (ScaleA, EVTComputeScaleB, Bias).
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
// Builds the nested argument structure for EVTCompute.
//   a_scales: float scales for A; one element selects the scalar path.
//   b_scales: float scales for B; one element selects the scalar path.
//   azp_adj:  int32 adjustment term (J @ B).
//   azp:      int32 per-token zero points.
//   bias:     optional bias in ElementD; a nullptr is passed when absent.
// Each brace-initializer lists the children's arguments in the same order as
// the children appear in the corresponding Sm90EVT alias above.
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
// Accum is the first child of EVTComputeAcc and is given empty arguments.
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
template <typename ElementAB_, typename ElementD_, template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_, template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule, typename TileShape, typename ClusterShape, typename KernelSchedule,
...@@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == c.dtype(), TORCH_CHECK(bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype()); "currently bias dtype must match output dtype ", c.dtype());
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>( return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
c, a, b, a_scales, b_scales, *bias); c, a, b, a_scales, b_scales, *bias);
} else { } else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales, return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
b_scales); c, a, b, a_scales, b_scales);
} }
} }
...@@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, ...@@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) { if (azp) {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>( return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias); out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else { } else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>( return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias); out, a, b, a_scales, b_scales, azp_adj, bias);
} }
} }
......
This diff is collapsed.
...@@ -171,6 +171,10 @@ struct MacheteCollectiveMma { ...@@ -171,6 +171,10 @@ struct MacheteCollectiveMma {
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
Int<DispatchPolicy::Stages>{}))); Int<DispatchPolicy::Stages>{})));
using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy(
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
Int<DispatchPolicy::Stages>{})));
using SmemLayoutAtomARowMajor = using SmemLayoutAtomARowMajor =
decltype(rs_smem_selector<GmmaMajorA, ElementA, decltype(rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<0>(TileShape_MNK{})),
...@@ -288,14 +292,7 @@ struct MacheteCollectiveMma { ...@@ -288,14 +292,7 @@ struct MacheteCollectiveMma {
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
"SmemLayoutAtomScale must evenly divide tile k shape."); "SmemLayoutAtomScale must evenly divide tile k shape.");
// Tile along modes in a way that maximizes the TMA box size. // Tile along modes in a way that maximizes the TMA box size
using SmemLayoutACopy = decltype(tile_to_shape(
SmemLayoutAtomARowMajor{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape( using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{}, SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
...@@ -428,12 +425,12 @@ struct MacheteCollectiveMma { ...@@ -428,12 +425,12 @@ struct MacheteCollectiveMma {
// clang-format on // clang-format on
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy(
make_shape(int32_t(0), int32_t(0), int32_t(0))))); make_shape(int32_t(0), int32_t(0), int32_t(0)))));
using ATensor = decltype(make_tensor( using ATensor = decltype(make_tensor(
get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
shape(GmemLayoutA::TVbNbKL_to_offset( shape(GmemLayoutA::TVbNbKL_to_offset_copy(
make_shape(int32_t(0), int32_t(0), int32_t(0)))), make_shape(int32_t(0), int32_t(0), int32_t(0)))),
PrepackedStrideA{})); PrepackedStrideA{}));
...@@ -450,8 +447,8 @@ struct MacheteCollectiveMma { ...@@ -450,8 +447,8 @@ struct MacheteCollectiveMma {
static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
return make_tma_copy<TmaElementA>( return make_tma_copy<TmaElementA>(
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}),
shape(SmemLayoutA{}(_, _, cute::Int<0>{})), shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
} }
...@@ -584,7 +581,7 @@ struct MacheteCollectiveMma { ...@@ -584,7 +581,7 @@ struct MacheteCollectiveMma {
typename Params::TMA_Scale tma_load_scale; typename Params::TMA_Scale tma_load_scale;
typename Params::TMA_Zero tma_load_zero; typename Params::TMA_Zero tma_load_zero;
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
tma_load_a = make_tma_copy_A( tma_load_a = make_tma_copy_A(
make_logical_tensor(ptr_A, shape(layout), stride(layout))); make_logical_tensor(ptr_A, shape(layout), stride(layout)));
...@@ -722,7 +719,7 @@ struct MacheteCollectiveMma { ...@@ -722,7 +719,7 @@ struct MacheteCollectiveMma {
// (TILE_V,TILE_B,m,k,l) // (TILE_V,TILE_B,m,k,l)
auto make_gA_mkl = [&]() { auto make_gA_mkl = [&]() {
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L));
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout));
return local_tile(mA_mkl, return local_tile(mA_mkl,
make_shape(size<0>(layout), PPBlocksPerTile_MK{}), make_shape(size<0>(layout), PPBlocksPerTile_MK{}),
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "cutlass_extensions/cute_utils.cuh" #include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh" #include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/torch_utils.hpp"
#include "machete_collective_builder.cuh" #include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh" #include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh" #include "machete_interleaving_utils.cuh"
...@@ -37,27 +39,42 @@ using namespace cute; ...@@ -37,27 +39,42 @@ using namespace cute;
// W is quantized, in this situation or right-hand operand is quantized so // W is quantized, in this situation or right-hand operand is quantized so
// we compute the transpose to move it to the left-hand side. // we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_, template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename ScaleT, typename ZeroT, typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
class KernelSchedule, typename ScheduleConfig, bool with_C, typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
bool with_scales, bool with_zeropoints> typename ScheduleConfig>
struct MacheteKernelTemplate { struct MacheteKernelTemplate {
static constexpr bool with_C = false; // not ever used
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
static constexpr bool with_group_zeropoints =
!std::is_same_v<GroupZeroT, void>;
static constexpr bool with_channel_scales =
!std::is_same_v<ChannelScaleT, void>;
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
using MmaType = ElementA_; using MmaType = ElementA_;
using ElementA = ElementA_; using ElementA = ElementA_;
using ElementB = ElementB_; using ElementB = ElementB_;
using ElementD = ElementD_; using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>; using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementZ = ZeroT; using ElementAccumulator = AccumulatorT;
using ElementS = ScaleT;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementCompute = AccumulatorT; // For Epilogue using ElementCompute = AccumulatorT; // For Epilogue
// Use dummy values when we don't have scales or zeropoints
using ElementZGroup =
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
using ElementSGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementConvertGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementSChannel =
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
using ElementSToken =
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
using BTypeTuple = cute::conditional_t< using BTypeTuple = cute::conditional_t<
with_scales, with_group_scales,
cute::conditional_t<with_zeropoints, cute::conditional_t<with_group_zeropoints,
cute::tuple<ElementB, ElementS, ElementZ>, cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
cute::tuple<ElementB, ElementS>>, cute::tuple<ElementB, ElementSGroup>>,
ElementB>; ElementB>;
using LayoutA = cutlass::layout::RowMajor; using LayoutA = cutlass::layout::RowMajor;
...@@ -71,8 +88,8 @@ struct MacheteKernelTemplate { ...@@ -71,8 +88,8 @@ struct MacheteKernelTemplate {
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>; using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>; using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>; using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>; using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZ = StrideS; using StrideZGroup = StrideSGroup;
using LayoutA_Transpose = using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type; typename cutlass::layout::LayoutTranspose<LayoutA>::type;
...@@ -85,8 +102,8 @@ struct MacheteKernelTemplate { ...@@ -85,8 +102,8 @@ struct MacheteKernelTemplate {
using OperatorClass = cutlass::arch::OpClassTensorOp; using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB = using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT, PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
LayoutA_Transpose, KernelSchedule>; AccumulatorT, LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK = static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value; 128 * 8 / cutlass::sizeof_bits<MmaType>::value;
...@@ -103,12 +120,42 @@ struct MacheteKernelTemplate { ...@@ -103,12 +120,42 @@ struct MacheteKernelTemplate {
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler; using TileScheduler = typename ScheduleConfig::TileScheduler;
static_assert(
(!with_channel_scales && !with_token_scales) ||
((with_channel_scales && with_token_scales) &&
std::is_same_v<ElementSChannel, ElementSToken>),
"Currently token and channel scales (if present) must be the same type");
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
// Currently only supports float scales
using ChTokScalesEpilogue =
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
EpilogueDescriptor>;
static_assert((with_channel_scales || with_token_scales) ||
(std::is_same_v<ElementSChannel, float> &&
std::is_same_v<ElementSToken, float>),
"Currently token and channel scales (if present) must be float "
"(and if one is present the other must be too)");
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90AccFetch>;
using EVTCompute =
std::conditional_t<with_channel_scales || with_token_scales,
typename ChTokScalesEpilogue::EVTCompute,
StoreEpilogueCompute>;
// EVTCompute
using CollectiveEpilogue = using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder< typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
EpilogueSchedule>::CollectiveOp; EVTCompute>::CollectiveOp;
using CollectiveMainloop = using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder< typename cutlass::gemm::collective::VLLMCollectiveBuilder<
...@@ -131,26 +178,44 @@ struct MacheteKernelTemplate { ...@@ -131,26 +178,44 @@ struct MacheteKernelTemplate {
using MainloopArguments = typename GemmKernel::MainloopArguments; using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments; using EpilogueArguments = typename GemmKernel::EpilogueArguments;
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
typename ShapeZ>
static Arguments create_arguments( static Arguments create_arguments(
cudaStream_t stream, cudaStream_t stream,
ElementA const* A_ptr, // A is an MxK matrix torch::Tensor const& A, // MxK matrix
Layout<ShapeA, StrideA> const& layout_A, torch::Tensor const& B, // KxN prepacked matrix
ElementB const* B_ptr, // B is an KxN prepacked matrix torch::Tensor& D, // MxN matrix
ElementD* D_ptr, // D is an MxN matrix c10::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
Layout<ShapeD, StrideD> const& layout_D, c10::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
ElementC const* C_ptr, // C is an MxN matrix c10::optional<int64_t> maybe_group_size,
std::optional<Layout<ShapeC, StrideC>> const& layout_C, c10::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
ElementS const* S_ptr, // S is an scale_KxN matrix c10::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
std::optional<Layout<ShapeS, StrideS>> const& layout_S, {
ElementZ const* Z_ptr, // Z is an scale_KxN matrix static_assert(!with_group_zeropoints || with_group_scales);
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
ElementCompute alpha, ElementCompute beta, int M = A.size(0), N = B.size(1), K = A.size(1);
std::optional<int> maybe_group_size) { TORCH_CHECK(D.size(0) == M && D.size(1) == N);
static_assert(!with_zeropoints || with_scales);
auto layout_A = make_cute_layout<StrideA>(A, "A");
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_S_group =
maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
auto layout_Z_group =
maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
auto unwrap = [](auto const& t) {
return t ? t->const_data_ptr() : nullptr;
};
auto A_ptr = static_cast<ElementA const*>(A.const_data_ptr());
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
auto D_ptr = static_cast<ElementD*>(D.mutable_data_ptr());
auto S_group_ptr =
static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
auto S_channel_ptr =
static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
auto S_token_ptr =
static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
int const group_size = int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K); maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
...@@ -159,26 +224,28 @@ struct MacheteKernelTemplate { ...@@ -159,26 +224,28 @@ struct MacheteKernelTemplate {
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_C) { if constexpr (with_group_scales) {
TORCH_CHECK(C_ptr && layout_C); TORCH_CHECK(S_group_ptr && layout_S_group);
TORCH_CHECK((size<0>(*layout_S_group) == scale_k &&
size<1>(*layout_S_group) == N));
} else { } else {
TORCH_CHECK(!C_ptr, "C not supported"); TORCH_CHECK(!S_group_ptr, "Scales not supported");
} }
if constexpr (with_scales) { if constexpr (with_group_zeropoints) {
TORCH_CHECK(S_ptr && layout_S); TORCH_CHECK(Z_group_ptr && layout_Z_group);
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N)); TORCH_CHECK((size<0>(*layout_Z_group) == scale_k &&
size<1>(*layout_Z_group) == N));
TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
"Scales and zeros must have the same layout");
} else { } else {
TORCH_CHECK(!S_ptr, "Scales not supported"); TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported");
} }
if constexpr (with_zeropoints) { if constexpr (with_channel_scales || with_token_scales) {
TORCH_CHECK(Z_ptr && layout_Z); TORCH_CHECK(
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N)); (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
TORCH_CHECK(layout_S && *layout_Z == *layout_S, (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
} }
// Transpose A and D // Transpose A and D
...@@ -186,24 +253,33 @@ struct MacheteKernelTemplate { ...@@ -186,24 +253,33 @@ struct MacheteKernelTemplate {
// for B (which is At) // for B (which is At)
auto stride_At = layout_A.stride(); auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
auto stride_Ct = stride_Dt;
if (layout_C) {
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
}
MainloopArguments mainloop_arguments{}; MainloopArguments mainloop_arguments{};
EpilogueArguments epilogue_arguments{ // {Accum, C, C_layout, D, D}
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt}; EpilogueArguments epilogue_arguments{};
if constexpr (with_channel_scales || with_token_scales) {
epilogue_arguments =
EpilogueArguments{ChTokScalesEpilogue::prepare_args(
*maybe_ch_scales, *maybe_tok_scales),
nullptr,
{},
D_ptr,
stride_Dt};
} else {
epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
}
if constexpr (with_scales && with_zeropoints) { if constexpr (with_group_scales && with_group_zeropoints) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_ptr, stride_S, group_size, Z_ptr};
} else if constexpr (with_scales) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments = MainloopArguments{ mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size}; B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
} else if constexpr (with_group_scales) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size};
} else { } else {
mainloop_arguments = mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
......
...@@ -5,73 +5,61 @@ ...@@ -5,73 +5,61 @@
#include "machete_mm_kernel.cuh" #include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/torch_utils.hpp"
#include "core/scalar_type.hpp"
namespace machete { namespace machete {
struct PyTorchArguments { struct MMArgs {
torch::Tensor const& A; torch::Tensor const& A;
torch::Tensor const& B; torch::Tensor const& B;
c10::optional<torch::Tensor> const& scales; vllm::ScalarType const& b_type;
c10::optional<torch::Tensor> const& zeros; c10::optional<at::ScalarType> const& maybe_out_type;
c10::optional<int64_t> group_size; c10::optional<torch::Tensor> const& maybe_group_scales;
c10::optional<torch::Tensor> const& C; c10::optional<torch::Tensor> const& maybe_group_zeros;
c10::optional<double> alpha; c10::optional<int64_t> maybe_group_size;
c10::optional<double> beta; c10::optional<torch::Tensor> const& maybe_channel_scales;
c10::optional<std::string> schedule; c10::optional<torch::Tensor> const& maybe_token_scales;
c10::optional<std::string> maybe_schedule;
}; };
struct SupportedSchedulesArgs {
at::ScalarType a_type;
vllm::ScalarType b_type;
c10::optional<at::ScalarType> maybe_group_scales_type;
c10::optional<at::ScalarType> maybe_group_zeros_type;
c10::optional<at::ScalarType> maybe_channel_scales_type;
c10::optional<at::ScalarType> maybe_token_scales_type;
c10::optional<at::ScalarType> maybe_out_type;
};
torch::Tensor mm_dispatch(MMArgs args);
std::vector<std::string> supported_schedules_dispatch(
SupportedSchedulesArgs args);
template <typename MacheteKernel> template <typename MacheteKernel>
torch::Tensor run_impl(PyTorchArguments args) { torch::Tensor run_impl(MMArgs args) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
auto device = args.A.device(); auto device = args.A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index()); auto stream = at::cuda::getCurrentCUDAStream(device.index());
using EleA = typename MacheteKernel::ElementA;
using EleB = typename MacheteKernel::ElementB;
using EleC = typename MacheteKernel::ElementC;
using EleD = typename MacheteKernel::ElementD;
using EleScale = typename MacheteKernel::ElementS;
using EleZero = typename MacheteKernel::ElementZ;
using StrideA = typename MacheteKernel::StrideA;
using StrideC = typename MacheteKernel::StrideC;
using StrideD = typename MacheteKernel::StrideD;
using StrideS = typename MacheteKernel::StrideS;
using StrideZ = typename MacheteKernel::StrideZ;
int M = args.A.size(0); int M = args.A.size(0);
int N = args.B.size(1); int N = args.B.size(1);
int K = args.A.size(1); int K = args.A.size(1);
// Allocate output // Allocate output
torch::Tensor D = torch::Tensor D = torch::empty(
torch::empty({M, N}, torch::TensorOptions() {M, N},
.dtype(equivalent_scalar_type_v<EleD>) torch::TensorOptions()
.device(device)); .dtype(equivalent_scalar_type_v<typename MacheteKernel::ElementD>)
.device(device));
auto const &A = args.A, &B = args.B;
auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
auto S_ptr =
static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
auto Z_ptr =
static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
auto arguments = MacheteKernel::create_arguments( auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, stream, //
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
args.group_size); args.maybe_group_size, args.maybe_channel_scales,
args.maybe_token_scales);
TORCH_CHECK(MacheteKernel::can_implement(arguments), TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments"); "Machete kernel cannot be run with these arguments");
...@@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) { ...@@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) {
return D; return D;
}; };
template <typename ElementA, typename ElementB, typename ElementD = ElementA,
typename AccumulatorT = float, typename ScaleT = ElementA,
typename ZeroT = ElementA>
struct GemmDispatcher {
static torch::Tensor dispatch(PyTorchArguments args);
static std::vector<std::string> supported_schedules();
};
}; // namespace machete }; // namespace machete
\ No newline at end of file
...@@ -6,31 +6,49 @@ ...@@ -6,31 +6,49 @@
namespace machete { namespace machete {
template <typename TileShapeNKL, typename ElementB, typename BInTensor, template <int threads, typename PrepackedLayoutB, typename BInTensor,
typename BTiledOutTensor> typename ElementB>
static __global__ void prepack_B_kernel(BInTensor B_in, static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
BTiledOutTensor B_tiled_out) { auto constexpr block_size =
auto tB_in = local_tile(B_in, TileShapeNKL{}, Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
make_coord(blockIdx.x, blockIdx.y, blockIdx.z)); auto constexpr eles_per_thread = Int<block_size / threads>{};
auto tB_out = B_tiled_out(make_coord(_, _), static_assert(block_size % threads == 0,
make_coord(blockIdx.x, blockIdx.y), blockIdx.z); "block_size must be divisible by the number of threads");
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{}, // Which pre-packed are we responsible for
Layout<Shape<_4, _32>, Stride<_32, _1>>{}, auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
Layout<Shape<_1, _2>>{}); auto tB_in = local_tile(
B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
blk_coord);
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); // Find the start offset in the output for this pre-packed block
auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
Tensor thr_tile_S = thr_copy.partition_S(tB_in); // Tensor representing a 1:1 mapping to the output space in 1D
Tensor thr_tile_D = thr_copy.partition_D(tB_out); auto tB_out_linear =
make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
make_layout(make_shape(block_size)));
// Mapping from output space (1D) to input space
auto tB_in_linear = make_tensor(
tB_in.data(),
tB_in.layout()
.compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
.with_shape(make_shape(block_size)));
// Tile for this specific thread (could have used a TiledCopy but these work
// best with 2d layouts, this is a simple 1d layout so local_tile is enough,
// we are also not that concerned with performance for this kernel)
auto thr_tB_in_linear =
local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
auto thr_tB_out_linear =
local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
// Construct a register-backed Tensor with the same shape as each thread's // Construct a register-backed Tensor with the same shape as each thread's
// partition // partition
auto fragment = make_tensor<ElementB>(shape(thr_tile_D)); auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
// Copy from GMEM to RMEM and from RMEM to GMEM copy(thr_tB_in_linear, fragment);
copy(tiled_copy, thr_tile_S, fragment); copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
} }
template <typename PrepackedLayoutB, typename InLayout> template <typename PrepackedLayoutB, typename InLayout>
...@@ -44,18 +62,15 @@ static void prepack_B_template( ...@@ -44,18 +62,15 @@ static void prepack_B_template(
TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{}); auto L_tiles = size<2>(B_layout);
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
auto B_tiled_out =
make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB> prepack_B_kernel<128, PrepackedLayoutB>
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out); <<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
} }
}; // namespace machete }; // namespace machete
\ No newline at end of file
...@@ -2,9 +2,17 @@ ...@@ -2,9 +2,17 @@
#include "machete_prepack_kernel.cuh" #include "machete_prepack_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/torch_utils.hpp"
#include "core/scalar_type.hpp"
namespace machete { namespace machete {
struct PrepackBArgs {
torch::Tensor const& B;
at::ScalarType a_type;
vllm::ScalarType b_type;
c10::optional<at::ScalarType> maybe_group_scales_type;
};
template <typename PrepackedLayoutB> template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) { torch::Tensor prepack_impl(torch::Tensor const B) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
...@@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) { ...@@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
return D; return D;
}; };
template <typename ElementA, typename ElementB, typename ElementD, torch::Tensor prepack_B_dispatch(PrepackBArgs args);
typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
typename ZeroT = cutlass::half_t>
struct PrepackBDispatcher {
static torch::Tensor dispatch(torch::Tensor B);
};
}; // namespace machete }; // namespace machete
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment