Commit 4d3a2c28 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.5' into v0.6.5-dev

parents 92ec5d8e 2d1b9baa
...@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { ...@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
// is the layout f(x) = x // is the layout f(x) = x
template <typename Layout> template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>) if constexpr (std::is_same_v<Layout, void>) {
return true; return true;
else { } else {
constexpr auto coalesced_layout = coalesce(Layout{}); constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 && if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) { stride<0>(coalesced_layout) == 1) {
......
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
// clang-format off // clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock { namespace cutlass::epilogue::threadblock {
......
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm::c2x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
// it would technically work but no use case as data_ptr is never nullptr
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
protected:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
}; // namespace vllm::c2x
\ No newline at end of file
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm::c3x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
A and B may be both either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
c10::optional<torch::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
return ArgumentType{a_args, evt_scale_b_args, bias_args};
}
};
}; // namespace vllm::c3x
\ No newline at end of file
...@@ -35,6 +35,35 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { ...@@ -35,6 +35,35 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
} }
} }
VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
**DataTypeSize, # type: ignore
**{
VLLMDataType.u4b8: 4,
VLLMDataType.u8b128: 8,
}
}
VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
VLLMDataType.u4b8: "vllm::kU4B8",
VLLMDataType.u8b128: "vllm::kU8B128",
DataType.u4: "vllm::kU4",
DataType.u8: "vllm::kU8",
DataType.s4: "vllm::kS4",
DataType.s8: "vllm::kS8",
DataType.f16: "vllm::kFloat16",
DataType.bf16: "vllm::kBfloat16",
}
VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType.u8: "at::ScalarType::Byte",
DataType.s8: "at::ScalarType::Char",
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
DataType.s32: "at::ScalarType::Int",
DataType.f16: "at::ScalarType::Half",
DataType.bf16: "at::ScalarType::BFloat16",
DataType.f32: "at::ScalarType::Float",
}
VLLMKernelScheduleTag: Dict[Union[ VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = { MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore **KernelScheduleTag, # type: ignore
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "cutlass/numeric_conversion.h" #include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh" #include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh" #include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_type_utils.cuh"
// this file extends: // this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h // https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
...@@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter { ...@@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter {
CUTLASS_DEVICE CUTLASS_DEVICE
static result_type convert(source_type const& source) { static result_type convert(source_type const& source) {
CUTE_INVALID_CONTROL_PATH( if (cute::elect_one_sync()) {
"InterleavedNumericArrayConverter not implemented\n"); if constexpr (std::is_same_v<IlvBlkLayout, void>) {
printf(
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
nameof_v<T>, nameof_v<S>, N);
} else {
printf(
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
"implemented\n",
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
}
__brkpt();
}
return {}; return {};
} }
...@@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter< ...@@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter<
result_type operator()(source_type const& s) const { return convert(s); } result_type operator()(source_type const& s) const { return convert(s); }
}; };
// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
// ....
template <typename RegConvert32bit, typename T, typename S, int N> template <typename RegConvert32bit, typename T, typename S, int N>
struct ArrayConverterPacked32Bit { struct ArrayConverterPacked32Bit {
using result_type = Array<T, N>; using result_type = Array<T, N>;
...@@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit { ...@@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit {
using ScalarConverter = NumericConverter<T, S>; using ScalarConverter = NumericConverter<T, S>;
template <typename PackedSrc> template <typename PackedSrc>
CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
if constexpr (sizeof(PackedSrc) == 1) { if constexpr (sizeof(PackedSrc) == 1) {
return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source)); return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 2) { } else if constexpr (sizeof(PackedSrc) == 2) {
return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source)); return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 4) {
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
} else { } else {
static_assert(sizeof(PackedSrc) == 4); static_assert(sizeof(PackedSrc) == 8);
return reinterpret_cast<const uint32_t&>(source); return reinterpret_cast<Array<uint32_t, 2> const&>(src);
} }
} }
...@@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit { ...@@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit {
static_assert(std::is_same_v<typename PackedSrcType::Element, S>); static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
static_assert(std::is_same_v<typename PackedResultType::Element, T>); static_assert(std::is_same_v<typename PackedResultType::Element, T>);
return RegConvert32bit::template convert<PackedResultType>(to_reg(source)); return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
} }
friend class detail::VectorizedConverter; friend class detail::VectorizedConverter;
...@@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit { ...@@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit {
} }
}; };
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
// into 2 32bit register.
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3, //
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7, //
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11, //
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
uint32_t src) {
cutlass::AlignedArray<uint32_t, 2> r;
// Determines if the value is in the top half of the LUT if set or
// (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move
// into bit position 0x4 of each nibble so when or'd with final_prmt_base it
// selects the correct candidate. When elements in final_prmt_base
// are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
// are < 0x4, the low candidate is selected (i.e. LUT[0:7])
uint32_t high_bit = (src & 0x88888888) >> 1;
// `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT
// (selects correct high or low candidate)
const uint32_t final_prmt_base = 0x32103210;
// Ignore the high bit when indexing into LUT, for each 4bit value
// we index into both the high and low candidates then use
// high_bit | final_prmt_base to select the correct candidate
uint32_t lut_idx = (src & 0x77777777);
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
(uint32_t(d) << 24);
};
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
uint32_t final_prmt_idx = final_prmt_base | high_bit;
// This uses a look up table to convert packed int4s to packed int8s,
// using the int4 value as the index to prmt. It first select both the
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
// select the correct candidate.
asm volatile(
"{\n"
" .reg .b32 low, high;\n"
" prmt.b32 low, %1, %2, %5;\n"
" prmt.b32 high, %3, %4, %5;\n"
" prmt.b32 %0, low, high, %6;\n"
"}\n"
: "=r"(r[ii])
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
"r"(final_prmt_idx));
}
return r;
};
// for Array<int8_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
0xFC, 0xFD, 0xFE, 0xFF, //
0x00, 0x01, 0x02, 0x03, //
0x04, 0x05, 0x06, 0x07>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::float_e4m3_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::float_e4m3_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<cutlass::float_e4m3_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
0xC8, 0xC4, 0xC0, 0xB8, //
0x00, 0x38, 0x40, 0x44, //
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N> // for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N> template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> { struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
...@@ -148,7 +282,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> { ...@@ -148,7 +282,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>, ...@@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>, ...@@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -417,7 +554,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> { ...@@ -417,7 +554,8 @@ struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
// Hold output FP16s in reg. We need 1 reg for every 2 elements // Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
...@@ -469,7 +607,8 @@ struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> { ...@@ -469,7 +607,8 @@ struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
PackedResultType r; PackedResultType r;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
...@@ -513,7 +652,8 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> { ...@@ -513,7 +652,8 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src_reg = src_[0];
// Hold output BF16s in reg. We need 1 reg for every 2 elements // Hold output BF16s in reg. We need 1 reg for every 2 elements
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
...@@ -603,7 +743,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>, ...@@ -603,7 +743,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>, ...@@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
private: private:
struct RegConvert { struct RegConvert {
template <typename PackedResultType> template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray = using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2, cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>; sizeof(PackedResultType)>;
...@@ -788,6 +930,61 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> { ...@@ -788,6 +930,61 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {
#endif #endif
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<cutlass::half_t, N>;
struct RegConvert {
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <typename PackedResultType, int src_regs>
CUTLASS_DEVICE static PackedResultType convert(
Array<uint32_t, src_regs> src) {
// Hold output int8s in reg. We need 1 reg for every 4 elements
using RegArray = cutlass::AlignedArray<
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
RegArray r;
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
*reinterpret_cast<half2*>(&src[0]) =
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
if constexpr (src_regs > 1) {
*reinterpret_cast<half2*>(&src[1]) =
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
}
static_assert(PackedResultType::kElements <= 4);
uint32_t uint8s;
static constexpr uint32_t MASK_0246 = 0x6420;
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(uint8s)
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
"n"(MASK_0246));
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
return reinterpret_cast<PackedResultType&>(int8s);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass } // namespace cutlass
......
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
namespace cutlass {
template <typename T>
struct nameof {
static constexpr char const* value = "unknown";
};
template <typename T>
inline constexpr auto nameof_v = nameof<T>::value;
#define NAMEOF_TYPE(T) \
template <> \
struct nameof<T> { \
static constexpr char const* value = #T; \
};
NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)
NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)
NAMEOF_TYPE(vllm_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(vllm_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)
}; // namespace cutlass
\ No newline at end of file
...@@ -14,6 +14,20 @@ ...@@ -14,6 +14,20 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// TODO(luka/varun): use FP8_TYPE macro after refactoring
#ifndef USE_ROCM
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
......
#include <torch/all.h> #include "type_convert.cuh"
#include <ATen/cuda/CUDAContext.h> #include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "dispatch_utils.h"
#ifndef USE_ROCM #ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#else #else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif #endif
namespace vllm { namespace vllm {
...@@ -51,155 +43,6 @@ __global__ void rms_norm_kernel( ...@@ -51,155 +43,6 @@ __global__ void rms_norm_kernel(
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}
__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};
/* Function specialization in the case of FP16/BF16 tensors. /* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
......
/*
* This file contains the CUDA kernels for the fused quantized layernorm.
* The kernels correspond to the kernels in layernorm_kernels.cu, except they
* also produce quantized output directly.
* Currently, only static fp8 quantization is supported.
*/
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv);
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
#pragma unroll
for (int i = 0; i < width; ++i) {
out[id * width + i] =
scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
}
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv);
}
}
} // namespace vllm
void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
num_tokens, hidden_size);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
});
void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out, // [..., hidden_size],
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
...@@ -39,8 +39,6 @@ ...@@ -39,8 +39,6 @@
template<typename input_t, typename weight_t> template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream); void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
template <typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
template<typename input_t, typename weight_t> template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
...@@ -55,8 +53,12 @@ void set_conv_params_fwd(ConvParamsBase &params, ...@@ -55,8 +53,12 @@ void set_conv_params_fwd(ConvParamsBase &params,
const at::Tensor x, const at::Tensor x,
const at::Tensor weight, const at::Tensor weight,
const at::Tensor out, const at::Tensor out,
void* bias_ptr, const c10::optional<at::Tensor>& bias,
bool silu_activation) { bool silu_activation,
int64_t pad_slot_id,
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
// Reset the parameters // Reset the parameters
memset(&params, 0, sizeof(params)); memset(&params, 0, sizeof(params));
...@@ -65,33 +67,41 @@ void set_conv_params_fwd(ConvParamsBase &params, ...@@ -65,33 +67,41 @@ void set_conv_params_fwd(ConvParamsBase &params,
params.dim = dim; params.dim = dim;
params.seqlen = seqlen; params.seqlen = seqlen;
params.width = width; params.width = width;
params.pad_slot_id = pad_slot_id;
params.silu_activation = silu_activation; params.silu_activation = silu_activation;
// Set the pointers and strides. // Set the pointers and strides.
params.x_ptr = x.data_ptr(); params.x_ptr = x.data_ptr();
params.weight_ptr = weight.data_ptr(); params.weight_ptr = weight.data_ptr();
params.bias_ptr = bias_ptr; params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
params.out_ptr = out.data_ptr(); params.out_ptr = out.data_ptr();
// All stride are in elements, not bytes. // All stride are in elements, not bytes.
params.x_batch_stride = x.stride(0); params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
params.x_c_stride = x.stride(1); params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.x_l_stride = x.stride(-1); params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
const bool varlen = params.query_start_loc_ptr != nullptr;
params.x_batch_stride = x.stride(varlen ? 1 : 0);
params.x_c_stride = x.stride(varlen ? 0 : 1);
params.x_l_stride = x.stride(varlen ? 1 : -1);
params.weight_c_stride = weight.stride(0); params.weight_c_stride = weight.stride(0);
params.weight_width_stride = weight.stride(1); params.weight_width_stride = weight.stride(1);
params.out_batch_stride = out.stride(0); params.out_batch_stride = out.stride(varlen ? 1 : 0);
params.out_c_stride = out.stride(1); params.out_c_stride = out.stride(varlen ? 0 : 1);
params.out_l_stride = out.stride(-1); params.out_l_stride = out.stride(varlen ? 1 : -1);
} }
at::Tensor void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_, const c10::optional<at::Tensor> &bias_,
const c10::optional<at::Tensor> &seq_idx_, const c10::optional<at::Tensor> &conv_states,
const c10::optional<at::Tensor> &initial_states_, const c10::optional<at::Tensor> &query_start_loc,
const c10::optional<at::Tensor> &final_states_out_, const c10::optional<at::Tensor> &cache_indices,
bool silu_activation) { const c10::optional<at::Tensor> &has_initial_state,
bool silu_activation,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type(); auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type(); auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
...@@ -99,24 +109,22 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, ...@@ -99,24 +109,22 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
TORCH_CHECK(x.is_cuda()); TORCH_CHECK(x.is_cuda());
TORCH_CHECK(weight.is_cuda()); TORCH_CHECK(weight.is_cuda());
const bool varlen = query_start_loc.has_value() ? true : false;
const auto sizes = x.sizes(); const auto sizes = x.sizes();
const int batch_size = sizes[0]; const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
const int dim = sizes[1]; const int dim = varlen ? sizes[0] : sizes[1];
const int seqlen = sizes[2]; const int seqlen = varlen ? sizes[1] : sizes[2];
const int width = weight.size(-1); const int width = weight.size(-1);
if (varlen){
CHECK_SHAPE(x, batch_size, dim, seqlen); CHECK_SHAPE(x, dim, seqlen);
}
else {
CHECK_SHAPE(x, batch_size, dim, seqlen);
}
CHECK_SHAPE(weight, dim, width); CHECK_SHAPE(weight, dim, width);
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
if (is_channel_last) {
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
}
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
if (bias_.has_value()) { if (bias_.has_value()) {
auto bias = bias_.value(); auto bias = bias_.value();
...@@ -126,56 +134,51 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, ...@@ -126,56 +134,51 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
CHECK_SHAPE(bias, dim); CHECK_SHAPE(bias, dim);
} }
if (seq_idx_.has_value()) {
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
auto seq_idx = seq_idx_.value();
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
TORCH_CHECK(seq_idx.is_cuda());
TORCH_CHECK(seq_idx.is_contiguous());
CHECK_SHAPE(seq_idx, batch_size, seqlen);
}
at::Tensor out = torch::empty_like(x); if (has_initial_state.has_value()) {
auto has_initial_state_ = has_initial_state.value();
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
TORCH_CHECK(has_initial_state_.is_cuda());
CHECK_SHAPE(has_initial_state_, batch_size);
}
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
silu_activation);
if (seq_idx_.has_value()) { if (query_start_loc.has_value()) {
params.seq_idx_ptr = seq_idx_.value().data_ptr(); auto query_start_loc_ = query_start_loc.value();
} else { TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
params.seq_idx_ptr = nullptr; TORCH_CHECK(query_start_loc_.is_cuda());
} }
if (initial_states_.has_value()) {
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); if (cache_indices.has_value()) {
auto initial_states = initial_states_.value(); auto cache_indices_ = cache_indices.value();
TORCH_CHECK(initial_states.scalar_type() == input_type); TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(initial_states.is_cuda()); TORCH_CHECK(cache_indices_.is_cuda());
CHECK_SHAPE(initial_states, batch_size, dim, width - 1); CHECK_SHAPE(cache_indices_, batch_size);
TORCH_CHECK(initial_states.stride(1) == 1);
params.initial_states_ptr = initial_states.data_ptr();
params.initial_states_batch_stride = initial_states.stride(0);
params.initial_states_c_stride = initial_states.stride(1);
params.initial_states_l_stride = initial_states.stride(2);
} else {
params.initial_states_ptr = nullptr;
} }
if (final_states_out_.has_value()) { at::Tensor out = x;
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
auto final_states = final_states_out_.value(); ConvParamsBase params;
TORCH_CHECK(final_states.scalar_type() == input_type); set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
TORCH_CHECK(final_states.is_cuda()); bias_,
CHECK_SHAPE(final_states, batch_size, dim, width - 1); silu_activation,
TORCH_CHECK(final_states.stride(1) == 1); pad_slot_id,
params.final_states_ptr = final_states.data_ptr(); query_start_loc,
params.final_states_batch_stride = final_states.stride(0); cache_indices,
params.final_states_c_stride = final_states.stride(1); has_initial_state
params.final_states_l_stride = final_states.stride(2); );
if (conv_states.has_value()) {
auto conv_states_ = conv_states.value();
TORCH_CHECK(conv_states_.scalar_type() == input_type);
TORCH_CHECK(conv_states_.is_cuda());
params.conv_states_ptr = conv_states_.data_ptr();
params.conv_states_batch_stride = conv_states_.stride(0);
params.conv_states_c_stride = conv_states_.stride(1);
params.conv_states_l_stride = conv_states_.stride(2);
} else { } else {
params.final_states_ptr = nullptr; params.conv_states_ptr = nullptr;
} }
// Otherwise the kernel will be launched from cuda:0 device // Otherwise the kernel will be launched from cuda:0 device
...@@ -183,23 +186,21 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, ...@@ -183,23 +186,21 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
at::cuda::CUDAGuard device_guard{(char)x.get_device()}; at::cuda::CUDAGuard device_guard{(char)x.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
if (!is_channel_last) { causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
} else {
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
}
}); });
return out;
} }
at::Tensor void causal_conv1d_update(const at::Tensor &x,
causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state, const at::Tensor &conv_state,
const at::Tensor &weight, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_, const c10::optional<at::Tensor> &bias_,
bool silu_activation, bool silu_activation,
const c10::optional<at::Tensor> &conv_state_indices_) { const c10::optional<at::Tensor> &cache_seqlens_,
const c10::optional<at::Tensor> &conv_state_indices_,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type(); auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type(); auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
...@@ -214,9 +215,12 @@ causal_conv1d_update(const at::Tensor &x, ...@@ -214,9 +215,12 @@ causal_conv1d_update(const at::Tensor &x,
const auto sizes = x.sizes(); const auto sizes = x.sizes();
const int batch_size = sizes[0]; const int batch_size = sizes[0];
const int dim = sizes[1]; const int dim = sizes[1];
const int seqlen = sizes[2];
const int width = weight.size(-1); const int width = weight.size(-1);
const int conv_state_len = conv_state.size(2);
TORCH_CHECK(conv_state_len >= width - 1);
CHECK_SHAPE(x, batch_size, dim); CHECK_SHAPE(x, batch_size, dim, seqlen);
CHECK_SHAPE(weight, dim, width); CHECK_SHAPE(weight, dim, width);
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
...@@ -229,18 +233,31 @@ causal_conv1d_update(const at::Tensor &x, ...@@ -229,18 +233,31 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE(bias, dim); CHECK_SHAPE(bias, dim);
} }
at::Tensor out = torch::empty_like(x); at::Tensor out = x;
ConvParamsBase params; ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_.has_value() ? bias_.value().data_ptr() : nullptr, bias_,
silu_activation); silu_activation,
pad_slot_id);
params.conv_state_ptr = conv_state.data_ptr(); params.conv_state_ptr = conv_state.data_ptr();
params.conv_state_len = conv_state_len;
// All stride are in elements, not bytes. // All stride are in elements, not bytes.
params.conv_state_batch_stride = conv_state.stride(0); params.conv_state_batch_stride = conv_state.stride(0);
params.conv_state_c_stride = conv_state.stride(1); params.conv_state_c_stride = conv_state.stride(1);
params.conv_state_l_stride = conv_state.stride(2); params.conv_state_l_stride = conv_state.stride(2);
if (cache_seqlens_.has_value()) {
auto cache_seqlens = cache_seqlens_.value();
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
TORCH_CHECK(cache_seqlens.is_cuda());
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
CHECK_SHAPE(cache_seqlens, batch_size);
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
} else {
params.cache_seqlens = nullptr;
}
if (conv_state_indices_.has_value()) { if (conv_state_indices_.has_value()) {
auto conv_state_indices = conv_state_indices_.value(); auto conv_state_indices = conv_state_indices_.value();
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
...@@ -249,11 +266,11 @@ causal_conv1d_update(const at::Tensor &x, ...@@ -249,11 +266,11 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE(conv_state_indices, batch_size); CHECK_SHAPE(conv_state_indices, batch_size);
int conv_state_entries = conv_state.size(0); int conv_state_entries = conv_state.size(0);
CHECK_SHAPE(conv_state, conv_state_entries, dim, width); CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>(); params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
} else { } else {
CHECK_SHAPE(conv_state, batch_size, dim, width); CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
params.conv_state_indices_ptr = nullptr; params.conv_state_indices_ptr = nullptr;
} }
...@@ -264,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x, ...@@ -264,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
causal_conv1d_update_cuda<input_t, weight_t>(params, stream); causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
}); });
return out;
} }
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_> template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
...@@ -296,7 +312,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { ...@@ -296,7 +312,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth; constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNElts = Ktraits::kNElts; constexpr int kNElts = Ktraits::kNElts;
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
using input_t = typename Ktraits::input_t; using input_t = typename Ktraits::input_t;
using vec_t = typename Ktraits::vec_t; using vec_t = typename Ktraits::vec_t;
using weight_t = typename Ktraits::weight_t; using weight_t = typename Ktraits::weight_t;
...@@ -309,20 +325,42 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { ...@@ -309,20 +325,42 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_); auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize); vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
const bool kVarlen = params.query_start_loc_ptr != nullptr;
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
const int batch_id = blockIdx.x; const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y; const int channel_id = blockIdx.y;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
+ channel_id * params.x_c_stride; + channel_id * params.x_c_stride;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride; weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ channel_id * params.out_c_stride; + channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]); float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
if (tidx == 0) { if (tidx == 0) {
input_t zeros[kNElts] = {0}; input_t initial_state[kNElts] = {0};
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0]; if (has_initial_state) {
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
}
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
} }
float weight_vals[kWidth]; float weight_vals[kWidth];
...@@ -330,14 +368,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { ...@@ -330,14 +368,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
constexpr int kChunkSize = kNThreads * kNElts; constexpr int kChunkSize = kNThreads * kNElts;
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
for (int chunk = 0; chunk < n_chunks; ++chunk) { for (int chunk = 0; chunk < n_chunks; ++chunk) {
input_t x_vals_load[2 * kNElts] = {0}; input_t x_vals_load[2 * kNElts] = {0};
if constexpr(kIsVecLoad) { if constexpr(kIsVecLoad) {
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
} else { } else {
__syncthreads(); __syncthreads();
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
} }
x += kChunkSize; x += kChunkSize;
__syncthreads(); __syncthreads();
...@@ -375,11 +413,78 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { ...@@ -375,11 +413,78 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
#pragma unroll #pragma unroll
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
if constexpr(kIsVecLoad) { if constexpr(kIsVecLoad) {
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
} else { } else {
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
} }
out += kChunkSize; out += kChunkSize;
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
// in case the final state is separated between the last "smem_exchange" and
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
// (which occurs when `final_state_position` is a non-positivie index)
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
input_t vals_load[kNElts] = {0};
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
// chunk = n_chunks - 2, a segment of the final state sits in the last index
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
#pragma unroll
for (int w = 0; w < -final_state_position; ++w){
conv_states[w] = vals_load[kNElts + final_state_position + w];
}
}
if ((chunk == n_chunks - 1) && tidx == 0){
// chunk = n_chunks - 1, the second segment of the final state first positions
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
for (int w = -final_state_position; w < kWidth - 1; ++w){
conv_states[w] = vals_load[w + final_state_position];
}
return;
}
}
}
// Final state is stored in the smem_exchange last token slot,
// in case seqlen < kWidth, we would need to take the final state from the
// initial state which is stored in conv_states
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
// and load it into conv_state accordingly
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
if (conv_states != nullptr && tidx == last_thread) {
input_t x_vals_load[kNElts * 2] = {0};
// in case we are on the first kWidth tokens
if (last_thread == 0 && seqlen < kWidth){
// Need to take the initial state
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
const int offset = seqlen - (kWidth - 1);
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
// pad the existing state
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
}
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
if (offset + w >= 0)
conv_states[w] = x_vals_load[offset + w ];
}
}
else {
// in case the final state is in between the threads data
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
// illegal access error on H100.
// Therefore, we access last_thread + 1, only if the final state data sits there
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
}
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
conv_states[w] = x_vals_load[offset + w ];
}
}
} }
} }
...@@ -387,7 +492,8 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { ...@@ -387,7 +492,8 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
template<int kNThreads, int kWidth, typename input_t, typename weight_t> template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) { void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { const bool kVarlen = params.query_start_loc_ptr != nullptr;
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>; using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize; constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch, params.dim); dim3 grid(params.batch, params.dim);
...@@ -422,220 +528,11 @@ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) { ...@@ -422,220 +528,11 @@ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
} }
} }
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_channellast_fwd_kernel_traits {
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
using input_t = input_t_;
using weight_t = weight_t_;
static constexpr int kNThreads = kNThreads_;
static_assert(kNThreads % 32 == 0);
static constexpr int kNWarps = kNThreads / 32;
static constexpr int kWidth = kWidth_;
static constexpr int kChunkSizeL = kChunkSizeL_;
static constexpr int kNBytes = sizeof(input_t);
static_assert(kNBytes == 2 || kNBytes == 4);
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
static constexpr int kNEltsPerRow = 128 / kNBytes;
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
static constexpr bool kIsVecLoad = kIsVecLoad_;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
// sizeof(typename BlockStoreT::TempStorage)});
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
};
template<typename Ktraits, bool kHasSeqIdx>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNElts = Ktraits::kNElts;
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
using input_t = typename Ktraits::input_t;
using vec_t = typename Ktraits::vec_t;
using weight_t = typename Ktraits::weight_t;
// Shared memory.
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
const int batch_id = blockIdx.x;
const int chunk_l_id = blockIdx.y;
const int chunk_c_id = blockIdx.z;
const int tid = threadIdx.x;
const int l_idx = tid / kNThreadsPerC;
const int c_idx = tid % kNThreadsPerC;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
// The last L-chunk will also have enough info to write to final states, since it also contain a few x values
// from the previous L-chunk.
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
#pragma unroll
for (int l = 0; l < Ktraits::kNLoads; ++l) {
input_t x_vals_load[kNElts] = {0};
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
}
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
}
// Load the elements from the previous chunk that are needed for convolution.
if (l_idx < kWidth - 1) {
input_t x_vals_load[kNElts] = {0};
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
} else if (initial_states != nullptr
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
}
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
}
__syncthreads();
if (final_states != nullptr
&& l_idx < kWidth - 1
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
}
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
static_assert(kNThreadsPerRow <= 32);
const int row_idx = tid / kNThreadsPerRow;
const int col_idx = tid % kNThreadsPerRow;
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
float weight_vals[kWidth] = {0};
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
}
}
float x_vals[kWidth - 1 + kLPerThread];
#pragma unroll
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
}
int seq_idx_thread[kWidth - 1 + kLPerThread];
if constexpr (kHasSeqIdx) {
#pragma unroll
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
}
}
float out_vals[kLPerThread];
#pragma unroll
for (int i = 0; i < kLPerThread; ++i) {
out_vals[i] = bias_val;
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
if constexpr (!kHasSeqIdx) {
out_vals[i] += weight_vals[w] * x_vals[i + w];
} else {
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
}
}
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
}
__syncthreads();
#pragma unroll
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
__syncthreads();
#pragma unroll
for (int l = 0; l < Ktraits::kNLoads; ++l) {
input_t out_vals_store[kNElts];
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
}
}
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
// constexpr int kSmemSize = Ktraits::kSmemSize;
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
dim3 block(Ktraits::kNThreads);
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
// if (kSmemSize >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
// }
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
template<typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
if (params.width == 2) {
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
} else if (params.width == 3) {
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
} else if (params.width == 4) {
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
}
}
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream); template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream); template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream); template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
///////
...@@ -649,7 +546,7 @@ struct Causal_conv1d_update_kernel_traits { ...@@ -649,7 +546,7 @@ struct Causal_conv1d_update_kernel_traits {
static_assert(kNBytes == 2 || kNBytes == 4); static_assert(kNBytes == 2 || kNBytes == 4);
}; };
template<typename Ktraits> template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads) __global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) { void causal_conv1d_update_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth; constexpr int kWidth = Ktraits::kWidth;
...@@ -660,6 +557,8 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { ...@@ -660,6 +557,8 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
const int batch_id = blockIdx.x; const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y * kNThreads + tidx; const int channel_id = blockIdx.y * kNThreads + tidx;
if (channel_id >= params.dim) return;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ channel_id * params.x_c_stride; + channel_id * params.x_c_stride;
...@@ -668,6 +567,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { ...@@ -668,6 +567,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id ? batch_id
: params.conv_state_indices_ptr[batch_id]; : params.conv_state_indices_ptr[batch_id];
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
if (conv_state_batch_coord == params.pad_slot_id){
return;
}
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride + conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride; + channel_id * params.conv_state_c_stride;
...@@ -675,35 +578,70 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { ...@@ -675,35 +578,70 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride; weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride; + channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]); float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
int state_len = params.conv_state_len;
int advance_len = params.seqlen;
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
int update_idx = cache_seqlen - (kWidth - 1);
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
float weight_vals[kWidth] = {0}; float weight_vals[kWidth] = {0};
if (channel_id < params.dim) { #pragma unroll
#pragma unroll for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
}
float x_vals[kWidth] = {0}; float x_vals[kWidth] = {0};
if (channel_id < params.dim) { if constexpr (!kIsCircularBuffer) {
#pragma unroll 2
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
}
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) {
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
}
x_vals[i] = float(state_val);
}
} else {
#pragma unroll #pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
x_vals[kWidth - 1] = float(x[0]); input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
x_vals[i] = float(state_val);
}
}
#pragma unroll 2
for (int i = 0; i < params.seqlen; ++i) {
input_t x_val = x[i * params.x_l_stride];
if constexpr (!kIsCircularBuffer) {
if (i < advance_len && state_len - advance_len + i >= 0) {
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
}
} else {
conv_state[update_idx * params.conv_state_l_stride] = x_val;
++update_idx;
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
}
x_vals[kWidth - 1] = float(x_val);
float out_val = bias_val;
#pragma unroll
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
out[i * params.out_l_stride] = input_t(out_val);
// Shift the input buffer by 1
#pragma unroll #pragma unroll
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
} }
float out_val = bias_val;
#pragma unroll
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
if (channel_id < params.dim) { out[0] = input_t(out_val); }
} }
template<int kNThreads, int kWidth, typename input_t, typename weight_t> template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) { void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>; using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
auto kernel = &causal_conv1d_update_kernel<Ktraits>; auto kernel = params.cache_seqlens == nullptr
? &causal_conv1d_update_kernel<Ktraits, false>
: &causal_conv1d_update_kernel<Ktraits, true>;
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params); kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
} }
......
...@@ -13,6 +13,7 @@ struct ConvParamsBase { ...@@ -13,6 +13,7 @@ struct ConvParamsBase {
using index_t = uint32_t; using index_t = uint32_t;
int batch, dim, seqlen, width; int batch, dim, seqlen, width;
int64_t pad_slot_id;
bool silu_activation; bool silu_activation;
index_t x_batch_stride; index_t x_batch_stride;
...@@ -24,6 +25,7 @@ struct ConvParamsBase { ...@@ -24,6 +25,7 @@ struct ConvParamsBase {
index_t out_c_stride; index_t out_c_stride;
index_t out_l_stride; index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride; index_t conv_state_batch_stride;
index_t conv_state_c_stride; index_t conv_state_c_stride;
index_t conv_state_l_stride; index_t conv_state_l_stride;
...@@ -35,6 +37,10 @@ struct ConvParamsBase { ...@@ -35,6 +37,10 @@ struct ConvParamsBase {
void *__restrict__ out_ptr; void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr; void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
// For the continuous batching case. Makes it so that the mamba state for // For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor. // the current batch doesn't need to be a contiguous tensor.
...@@ -52,6 +58,11 @@ struct ConvParamsBase { ...@@ -52,6 +58,11 @@ struct ConvParamsBase {
index_t final_states_batch_stride; index_t final_states_batch_stride;
index_t final_states_l_stride; index_t final_states_l_stride;
index_t final_states_c_stride; index_t final_states_c_stride;
void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
}; };
......
...@@ -21,6 +21,7 @@ struct SSMParamsBase { ...@@ -21,6 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio; int dim_ngroups_ratio;
bool is_variable_B; bool is_variable_B;
bool is_variable_C; bool is_variable_C;
int64_t pad_slot_id;
bool delta_softplus; bool delta_softplus;
...@@ -54,10 +55,14 @@ struct SSMParamsBase { ...@@ -54,10 +55,14 @@ struct SSMParamsBase {
void *__restrict__ delta_ptr; void *__restrict__ delta_ptr;
void *__restrict__ delta_bias_ptr; void *__restrict__ delta_bias_ptr;
void *__restrict__ out_ptr; void *__restrict__ out_ptr;
void *__restrict__ x_ptr; void *__restrict__ ssm_states_ptr;
void *__restrict__ z_ptr; void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr; void *__restrict__ out_z_ptr;
void *__restrict__ index_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;
}; };
...@@ -201,7 +206,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u, ...@@ -201,7 +206,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadT::TempStorage &smem_load, typename Ktraits::BlockLoadT::TempStorage &smem_load,
int seqlen) { int seqlen) {
if constexpr (Ktraits::kIsEvenLen) { if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load); auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
using vec_t = typename Ktraits::vec_t; using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadVecT(smem_load_vec).Load( typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
...@@ -217,21 +222,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u, ...@@ -217,21 +222,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
} }
} }
template<typename Ktraits>
inline __device__ void load_index(int *u,
int (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
reinterpret_cast<uint4*>(u),
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
);
} else {
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
}
}
template<typename Ktraits> template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar, inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
...@@ -240,7 +230,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar, ...@@ -240,7 +230,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
int seqlen) { int seqlen) {
constexpr int kNItems = Ktraits::kNItems; constexpr int kNItems = Ktraits::kNItems;
typename Ktraits::input_t B_vals_load[kNItems]; typename Ktraits::input_t B_vals_load[kNItems];
if constexpr (Ktraits::kIsEvenLen) { if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight); auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t; using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
...@@ -263,7 +253,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out, ...@@ -263,7 +253,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out,
typename Ktraits::input_t write_vals[Ktraits::kNItems]; typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll #pragma unroll
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
if constexpr (Ktraits::kIsEvenLen) { if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store); auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
using vec_t = typename Ktraits::vec_t; using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockStoreVecT(smem_store_vec).Store( typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_, template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
bool kIsVariableB_, bool kIsVariableC_, bool kIsVariableB_, bool kIsVariableC_,
bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_> bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits { struct Selective_Scan_fwd_kernel_traits {
static_assert(kNItems_ % 4 == 0); static_assert(kNItems_ % 4 == 0);
using input_t = input_t_; using input_t = input_t_;
...@@ -38,22 +38,19 @@ struct Selective_Scan_fwd_kernel_traits { ...@@ -38,22 +38,19 @@ struct Selective_Scan_fwd_kernel_traits {
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
static_assert(kNItems % kNElts == 0); static_assert(kNItems % kNElts == 0);
static constexpr int kNLoads = kNItems / kNElts; static constexpr int kNLoads = kNItems / kNElts;
static constexpr bool kIsEvenLen = kIsEvenLen_; static constexpr bool kIsEvenLen = kVarlen_ ? false : kIsEvenLen_;
static constexpr bool kIsVariableB = kIsVariableB_; static constexpr bool kIsVariableB = kIsVariableB_;
static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kIsVariableC = kIsVariableC_;
static constexpr bool kHasZ = kHasZ_; static constexpr bool kHasZ = kHasZ_;
static constexpr bool kUseIndex = kUseIndex_; static constexpr bool kVarlen = kVarlen_;
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; static constexpr bool kDirectIO = kVarlen_ ? false : kIsEvenLen && kNLoads == 1;
static constexpr int kNLoadsIndex = kNItems / 4; static constexpr int kNLoadsIndex = kNItems / 4;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type; using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
using scan_t = float2; using scan_t = float2;
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>; using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>; !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
!(kIsEvenLen && kNLoadsIndex == 1) ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems , cub::BLOCK_LOAD_WARP_TRANSPOSE>; using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems , cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads , using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads ,
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>; !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
...@@ -65,8 +62,6 @@ struct Selective_Scan_fwd_kernel_traits { ...@@ -65,8 +62,6 @@ struct Selective_Scan_fwd_kernel_traits {
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>; using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
sizeof(typename BlockLoadVecT::TempStorage), sizeof(typename BlockLoadVecT::TempStorage),
sizeof(typename BlockLoadIndexT::TempStorage),
sizeof(typename BlockLoadIndexVecT::TempStorage),
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
sizeof(typename BlockStoreT::TempStorage), sizeof(typename BlockStoreT::TempStorage),
...@@ -80,7 +75,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -80,7 +75,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
constexpr bool kIsVariableB = Ktraits::kIsVariableB; constexpr bool kIsVariableB = Ktraits::kIsVariableB;
constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kIsVariableC = Ktraits::kIsVariableC;
constexpr bool kHasZ = Ktraits::kHasZ; constexpr bool kHasZ = Ktraits::kHasZ;
constexpr bool kUseIndex = Ktraits::kUseIndex; constexpr bool kVarlen = Ktraits::kVarlen;
constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNItems = Ktraits::kNItems; constexpr int kNItems = Ktraits::kNItems;
constexpr int kNRows = Ktraits::kNRows; constexpr int kNRows = Ktraits::kNRows;
...@@ -97,7 +92,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -97,7 +92,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan); // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_); auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_); auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_); auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize); auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
...@@ -108,17 +102,33 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -108,17 +102,33 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int batch_id = blockIdx.x; const int batch_id = blockIdx.x;
const int dim_id = blockIdx.y; const int dim_id = blockIdx.y;
const int group_id = dim_id / (params.dim_ngroups_ratio); const int group_id = dim_id / (params.dim_ngroups_ratio);
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride int seqlen = params.seqlen;
int sequence_start_index = batch_id;
if constexpr (kVarlen){
int *query_start_loc = reinterpret_cast<int *>(params.query_start_loc_ptr);
sequence_start_index = query_start_loc[batch_id];
seqlen = query_start_loc[batch_id + 1] - sequence_start_index;
}
const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
+ dim_id * kNRows * params.u_d_stride; + dim_id * kNRows * params.u_d_stride;
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
+ dim_id * kNRows * params.delta_d_stride; + dim_id * kNRows * params.delta_d_stride;
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride; weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride; weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride; weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;
int *index = !kUseIndex ? nullptr :reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;
float D_val[kNRows] = {0}; float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) { if (params.D_ptr != nullptr) {
...@@ -142,9 +152,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -142,9 +152,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// } // }
constexpr int kChunkSize = kNThreads * kNItems; constexpr int kChunkSize = kNThreads * kNItems;
for (int chunk = 0; chunk < params.n_chunks; ++chunk) { const int n_chunks = (seqlen + 2048 - 1) / 2048;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
int index_vals_load[kNRows][kNItems];
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
...@@ -152,15 +162,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -152,15 +162,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (!kDirectIO) { if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); } if (r > 0) { __syncthreads(); }
} }
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
if constexpr (!kDirectIO) { __syncthreads(); } if constexpr (!kDirectIO) { __syncthreads(); }
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
if constexpr (kUseIndex) {
load_index<Ktraits>(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize);
}
}
if constexpr (kUseIndex) {
index += kChunkSize;
} }
u += kChunkSize; u += kChunkSize;
delta += kChunkSize; delta += kChunkSize;
...@@ -195,9 +199,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -195,9 +199,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// If both B and C vary, this is unused. // If both B and C vary, this is unused.
weight_t BC_val[kNRows]; weight_t BC_val[kNRows];
weight_t B_vals[kNItems], C_vals[kNItems]; weight_t B_vals[kNItems], C_vals[kNItems];
if constexpr (kIsVariableB) { if constexpr (kIsVariableB) {
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals, load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1)); smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
if constexpr (!kIsVariableC) { if constexpr (!kIsVariableC) {
#pragma unroll #pragma unroll
for (int r = 0; r < kNRows; ++r) { for (int r = 0; r < kNRows; ++r) {
...@@ -208,7 +212,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -208,7 +212,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (kIsVariableC) { if constexpr (kIsVariableC) {
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals, load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 )); smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 ));
if constexpr (!kIsVariableB) { if constexpr (!kIsVariableB) {
#pragma unroll #pragma unroll
for (int r = 0; r < kNRows; ++r) { for (int r = 0; r < kNRows; ++r) {
...@@ -232,24 +236,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -232,24 +236,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
// Reset A bar for cumulative sequences (Real) if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
if constexpr (kUseIndex) { if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
if (index_vals_load[r][i] == 0) {
thread_data[i].x = 0.f;
}
}
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f); thread_data[i] = make_float2(1.f, 0.f);
} }
} }
} }
// Initialize running total // Initialize running total
scan_t running_prefix;
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0);
running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix); SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
typename Ktraits::BlockScanT(smem_scan).InclusiveScan( typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
...@@ -258,7 +254,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -258,7 +254,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing. // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix; smem_running_prefix[state_idx] = prefix_op.running_prefix;
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; if (chunk == n_chunks - 1) {
ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
}
} }
#pragma unroll #pragma unroll
for (int i = 0; i < kNItems; ++i) { for (int i = 0; i < kNItems; ++i) {
...@@ -270,7 +268,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -270,7 +268,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
} }
} }
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
...@@ -278,26 +276,26 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { ...@@ -278,26 +276,26 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (!kDirectIO) { if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); } if (r > 0) { __syncthreads(); }
} }
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
} }
if constexpr (kHasZ) { if constexpr (kHasZ) {
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
#pragma unroll #pragma unroll
for (int r = 0; r < kNRows; ++r) { for (int r = 0; r < kNRows; ++r) {
input_t z_vals[kNItems]; input_t z_vals[kNItems];
__syncthreads(); __syncthreads();
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
#pragma unroll #pragma unroll
for (int i = 0; i < kNItems; ++i) { for (int i = 0; i < kNItems; ++i) {
float z_val = z_vals[i]; float z_val = z_vals[i];
out_vals[r][i] *= z_val / (1 + expf(-z_val)); out_vals[r][i] *= z_val / (1 + expf(-z_val));
} }
__syncthreads(); __syncthreads();
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
} }
} }
...@@ -316,8 +314,8 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { ...@@ -316,8 +314,8 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
constexpr bool kIsVariableC = true; constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true; constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>; using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows); dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>; auto kernel = &selective_scan_fwd_kernel<Ktraits>;
...@@ -393,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params, ...@@ -393,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const size_t seqlen, const size_t seqlen,
const size_t dstate, const size_t dstate,
const size_t n_groups, const size_t n_groups,
const size_t n_chunks,
const bool is_variable_B, const bool is_variable_B,
const bool is_variable_C, const bool is_variable_C,
// device pointers // device pointers
...@@ -405,12 +402,16 @@ void set_ssm_params_fwd(SSMParamsBase &params, ...@@ -405,12 +402,16 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const torch::Tensor out, const torch::Tensor out,
const torch::Tensor z, const torch::Tensor z,
const torch::Tensor out_z, const torch::Tensor out_z,
void* D_ptr, const c10::optional<at::Tensor>& D,
void* delta_bias_ptr, const c10::optional<at::Tensor>& delta_bias,
void* x_ptr, const torch::Tensor ssm_states,
bool has_z, bool has_z,
bool delta_softplus, bool delta_softplus,
void* index_ptr) { const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool varlen,
int64_t pad_slot_id) {
// Reset the parameters // Reset the parameters
memset(&params, 0, sizeof(params)); memset(&params, 0, sizeof(params));
...@@ -420,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params, ...@@ -420,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.seqlen = seqlen; params.seqlen = seqlen;
params.dstate = dstate; params.dstate = dstate;
params.n_groups = n_groups; params.n_groups = n_groups;
params.n_chunks = n_chunks;
params.dim_ngroups_ratio = dim / n_groups; params.dim_ngroups_ratio = dim / n_groups;
params.pad_slot_id = pad_slot_id;
params.delta_softplus = delta_softplus; params.delta_softplus = delta_softplus;
...@@ -434,55 +435,86 @@ void set_ssm_params_fwd(SSMParamsBase &params, ...@@ -434,55 +435,86 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.A_ptr = A.data_ptr(); params.A_ptr = A.data_ptr();
params.B_ptr = B.data_ptr(); params.B_ptr = B.data_ptr();
params.C_ptr = C.data_ptr(); params.C_ptr = C.data_ptr();
params.D_ptr = D_ptr; params.D_ptr = D.has_value() ? D.value().data_ptr() : nullptr;
params.delta_bias_ptr = delta_bias_ptr; params.delta_bias_ptr = delta_bias.has_value() ? delta_bias.value().data_ptr() : nullptr;
params.out_ptr = out.data_ptr(); params.out_ptr = out.data_ptr();
params.x_ptr = x_ptr; params.ssm_states_ptr = ssm_states.data_ptr();
params.z_ptr = has_z ? z.data_ptr() : nullptr; params.z_ptr = has_z ? z.data_ptr() : nullptr;
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
params.index_ptr = index_ptr;
// All stride are in elements, not bytes. // All stride are in elements, not bytes.
params.A_d_stride = A.stride(0); params.A_d_stride = A.stride(0);
params.A_dstate_stride = A.stride(1); params.A_dstate_stride = A.stride(1);
if (!is_variable_B) {
params.B_d_stride = B.stride(0); if (varlen){
} else { params.B_batch_stride = B.stride(2);
params.B_batch_stride = B.stride(0); params.B_group_stride = B.stride(0);
params.B_group_stride = B.stride(1); params.B_dstate_stride = B.stride(1);
} params.C_batch_stride = C.stride(2);
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); params.C_group_stride = C.stride(0);
if (!is_variable_C) { params.C_dstate_stride = C.stride(1);
params.C_d_stride = C.stride(0);
} else { params.u_batch_stride = u.stride(1);
params.C_batch_stride = C.stride(0); params.u_d_stride = u.stride(0);
params.C_group_stride = C.stride(1); params.delta_batch_stride = delta.stride(1);
params.delta_d_stride = delta.stride(0);
if (has_z) {
params.z_batch_stride = z.stride(1);
params.z_d_stride = z.stride(0);
params.out_z_batch_stride = out_z.stride(1);
params.out_z_d_stride = out_z.stride(0);
}
params.out_batch_stride = out.stride(1);
params.out_d_stride = out.stride(0);
} }
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); else{
params.u_batch_stride = u.stride(0); if (!is_variable_B) {
params.u_d_stride = u.stride(1); params.B_d_stride = B.stride(0);
params.delta_batch_stride = delta.stride(0); } else {
params.delta_d_stride = delta.stride(1); params.B_batch_stride = B.stride(0);
if (has_z) { params.B_group_stride = B.stride(1);
params.z_batch_stride = z.stride(0); }
params.z_d_stride = z.stride(1); params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
params.out_z_batch_stride = out_z.stride(0); if (!is_variable_C) {
params.out_z_d_stride = out_z.stride(1); params.C_d_stride = C.stride(0);
} else {
params.C_batch_stride = C.stride(0);
params.C_group_stride = C.stride(1);
}
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
params.u_batch_stride = u.stride(0);
params.u_d_stride = u.stride(1);
params.delta_batch_stride = delta.stride(0);
params.delta_d_stride = delta.stride(1);
if (has_z) {
params.z_batch_stride = z.stride(0);
params.z_d_stride = z.stride(1);
params.out_z_batch_stride = out_z.stride(0);
params.out_z_d_stride = out_z.stride(1);
}
params.out_batch_stride = out.stride(0);
params.out_d_stride = out.stride(1);
} }
params.out_batch_stride = out.stride(0);
params.out_d_stride = out.stride(1);
} }
std::vector<torch::Tensor> void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
const c10::optional<torch::Tensor> &D_, const c10::optional<torch::Tensor> &D_,
const c10::optional<torch::Tensor> &z_, const c10::optional<torch::Tensor> &z_,
const c10::optional<torch::Tensor> &delta_bias_, const c10::optional<torch::Tensor> &delta_bias_,
bool delta_softplus, bool delta_softplus,
const c10::optional<torch::Tensor> &index_, const c10::optional<torch::Tensor> &query_start_loc,
const c10::optional<torch::Tensor> &x) { const c10::optional<torch::Tensor> &cache_indices,
const c10::optional<torch::Tensor> &has_initial_state,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = u.scalar_type(); auto input_type = u.scalar_type();
auto weight_type = A.scalar_type(); auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
...@@ -505,23 +537,37 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -505,23 +537,37 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
const auto sizes = u.sizes(); const auto sizes = u.sizes();
const int batch_size = sizes[0]; const bool varlen = query_start_loc.has_value();
const int dim = sizes[1]; const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
const int seqlen = sizes[2]; const int dim = varlen ? sizes[0] : sizes[1];
const int seqlen = varlen ? sizes[1] : sizes[2];
const int dstate = A.size(1); const int dstate = A.size(1);
const int n_groups = is_variable_B ? B.size(1) : 1; const int n_groups = varlen ? B.size(0) : B.size(1);
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
CHECK_SHAPE(u, batch_size, dim, seqlen); if (varlen) {
CHECK_SHAPE(delta, batch_size, dim, seqlen); CHECK_SHAPE(u, dim, seqlen);
CHECK_SHAPE(delta, dim, seqlen);
} else {
CHECK_SHAPE(u, batch_size, dim, seqlen);
CHECK_SHAPE(delta, batch_size, dim, seqlen);
}
CHECK_SHAPE(A, dim, dstate); CHECK_SHAPE(A, dim, dstate);
TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen ); if (varlen) {
CHECK_SHAPE(B, n_groups, dstate, seqlen);
} else {
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
}
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); if (varlen) {
CHECK_SHAPE(C, n_groups, dstate, seqlen);
} else {
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
}
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
if (D_.has_value()) { if (D_.has_value()) {
...@@ -539,13 +585,31 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -539,13 +585,31 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
CHECK_SHAPE(delta_bias, dim); CHECK_SHAPE(delta_bias, dim);
} }
if (index_.has_value()) {
auto index = index_.value();
TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); if (has_initial_state.has_value()) {
TORCH_CHECK(index.is_cuda()); auto has_initial_state_ = has_initial_state.value();
CHECK_SHAPE(index, batch_size, seqlen); TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
TORCH_CHECK(has_initial_state_.is_cuda());
CHECK_SHAPE(has_initial_state_, batch_size);
} }
if (query_start_loc.has_value()) {
auto query_start_loc_ = query_start_loc.value();
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(query_start_loc_.is_cuda());
}
if (cache_indices.has_value()) {
auto cache_indices_ = cache_indices.value();
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(cache_indices_.is_cuda());
CHECK_SHAPE(cache_indices_, batch_size);
}
at::Tensor z, out_z; at::Tensor z, out_z;
const bool has_z = z_.has_value(); const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
...@@ -553,32 +617,36 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -553,32 +617,36 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
TORCH_CHECK(z.scalar_type() == input_type); TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda()); TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
CHECK_SHAPE(z, batch_size, dim, seqlen); if (varlen){
out_z = torch::empty_like(z); CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
}
out_z = z;
const int n_chunks = (seqlen + 2048 - 1) / 2048;
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
// at::Tensor out = torch::empty_like(u);
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = torch::empty_like(delta); at::Tensor out = delta;
if (x.has_value()){ TORCH_CHECK(ssm_states.scalar_type() == input_type);
auto _x = x.value(); TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(_x.scalar_type() == weight_type); TORCH_CHECK(ssm_states.stride(-1) == 1);
TORCH_CHECK(_x.is_cuda());
TORCH_CHECK(_x.stride(-1) == 1);
CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
}
SSMParamsBase params; SSMParamsBase params;
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
u, delta, A, B, C, out, z, out_z, u, delta, A, B, C, out, z, out_z,
D_.has_value() ? D_.value().data_ptr() : nullptr, D_,
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, delta_bias_,
x.value().data_ptr(), ssm_states,
has_z, has_z,
delta_softplus, delta_softplus,
index_.has_value() ? index_.value().data_ptr() : nullptr); query_start_loc,
cache_indices,
has_initial_state,
varlen,
pad_slot_id
);
// Otherwise the kernel will be launched from cuda:0 device // Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing // Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)u.get_device()}; at::cuda::CUDAGuard device_guard{(char)u.get_device()};
...@@ -586,8 +654,5 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -586,8 +654,5 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream); selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
}); });
std::vector<at::Tensor> result = {out};
if (has_z) { result.push_back(out_z); }
return result;
} }
...@@ -38,6 +38,7 @@ using FragA = Vec<half2, 4>; ...@@ -38,6 +38,7 @@ using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales using FragS = Vec<half2, 1>; // quantization scales
using FragZP = Vec<half2, 4>;
// Predicated asynchronous global->shared copy; used for inputs A where we apply // Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16. // predication to handle batchsizes that are not multiples of 16.
...@@ -175,6 +176,46 @@ __device__ inline FragB dequant<vllm::kU8B128.id()>(int q) { ...@@ -175,6 +176,46 @@ __device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
return frag_b; return frag_b;
} }
template <>
__device__ inline FragB dequant<vllm::kU4.id()>(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline FragB dequant<vllm::kU8.id()>(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
...@@ -183,11 +224,10 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { ...@@ -183,11 +224,10 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
// Given 2 floats multiply by 2 scales (halves) __device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
__device__ inline void scale_float(float* c, FragS& s) { half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]);
__half* s_ptr = reinterpret_cast<__half*>(&s); frag_b[0] = __hsub2(frag_b[0], zp);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); frag_b[1] = __hsub2(frag_b[1], zp);
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
} }
// Same as above, but for act_order (each K is multiplied individually) // Same as above, but for act_order (each K is multiplied individually)
...@@ -205,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, ...@@ -205,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
frag_b[1] = __hmul2(frag_b[1], s_val_3_4); frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
} }
// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float* c, FragS& s) {
__half* s_ptr = reinterpret_cast<__half*>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock. // Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) { __device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
...@@ -248,10 +295,11 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id ...@@ -248,10 +295,11 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale // with a separate quantization scale
> >
__device__ inline void MarlinMoESingle( __device__ void MarlinMoESingle(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
...@@ -259,6 +307,8 @@ __device__ inline void MarlinMoESingle( ...@@ -259,6 +307,8 @@ __device__ inline void MarlinMoESingle(
const float* __restrict__ topk_weights, // float topk weights const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets, const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
...@@ -400,8 +450,12 @@ __device__ inline void MarlinMoESingle( ...@@ -400,8 +450,12 @@ __device__ inline void MarlinMoESingle(
int tb_n_warps = thread_n_blocks / 4; int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
constexpr int sorted_sh_stride = threads; // Zero-points sizes/strides
constexpr int sorted_gl_stride = threads; int zp_gl_stride = (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread. // Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
...@@ -442,6 +496,19 @@ __device__ inline void MarlinMoESingle( ...@@ -442,6 +496,19 @@ __device__ inline void MarlinMoESingle(
int s_sh_wr = threadIdx.x; int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as // We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in // we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case. // row-major in the latter case.
...@@ -453,23 +520,29 @@ __device__ inline void MarlinMoESingle( ...@@ -453,23 +520,29 @@ __device__ inline void MarlinMoESingle(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4; (threadIdx.x % 32) % 4;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
constexpr int num_row_threads = 4;
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
int sh_first_group_id = -1; int sh_first_group_id = -1;
int sh_num_groups = -1; int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32; constexpr int sh_max_num_groups = 32;
int shs_size;
if constexpr (has_act_order)
shs_size = sh_max_num_groups * s_sh_stride + threads;
else
shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines. // Shared memory storage for global fetch pipelines.
int4* sh_a = sh; int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_s = sh_g_idx + (stages * g_idx_stage); int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int* sh_sorted = (int*)(sh_s + shs_size); int4* sh_s = sh_zp + (stages * zp_sh_stage);
// Precompute which thread should not read memory in which iterations; this is // Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or // needed if there are more threads than required for a certain tilesize or
...@@ -525,8 +598,10 @@ __device__ inline void MarlinMoESingle( ...@@ -525,8 +598,10 @@ __device__ inline void MarlinMoESingle(
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs]; I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2]; FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
// Zero accumulators. // Zero accumulators.
auto zero_accums = [&]() { auto zero_accums = [&]() {
...@@ -633,6 +708,28 @@ __device__ inline void MarlinMoESingle( ...@@ -633,6 +708,28 @@ __device__ inline void MarlinMoESingle(
} }
} }
} }
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
} }
} }
// Insert a fence even when we are winding down the pipeline to ensure that // Insert a fence even when we are winding down the pipeline to ensure that
...@@ -640,15 +737,9 @@ __device__ inline void MarlinMoESingle( ...@@ -640,15 +737,9 @@ __device__ inline void MarlinMoESingle(
cp_async_fence(); cp_async_fence();
}; };
// TODO we are currently hitting illegal memory accesses when fetching auto fetch_zp_to_shared = [&]() {
// sorted_ids to shared data: fix this if (zp_sh_wr_pred) {
auto fetch_sorted_ids_to_shared = [&]() { cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
const int mpt = ceildiv(prob_m, threads);
for (int i = 0; i < mpt; i++) {
if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
}
} }
}; };
...@@ -799,8 +890,83 @@ __device__ inline void MarlinMoESingle( ...@@ -799,8 +890,83 @@ __device__ inline void MarlinMoESingle(
} }
}; };
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert(!has_zp || group_blocks != 0);
if constexpr (has_zp) {
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile. // Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) { auto matmul = [&](int k) {
if constexpr (has_zp) {
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) {
zp_quant_0 = frag_qzp[k % 2][0];
zp_quant_1 = zp_quant_0 >> 8;
} else {
static_assert(w_type.size_bits() == 8);
zp_quant_0 = frag_qzp[k % 2][0];
zp_quant_1 = frag_qzp[k % 2][1];
}
frag_zp_0 = dequant<w_type_id>(zp_quant_0);
frag_zp_1 = dequant<w_type_id>(zp_quant_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
}
// We have the m dimension as the inner loop in order to encourage overlapping // We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations. // dequantization and matmul operations.
#pragma unroll #pragma unroll
...@@ -818,6 +984,10 @@ __device__ inline void MarlinMoESingle( ...@@ -818,6 +984,10 @@ __device__ inline void MarlinMoESingle(
FragB frag_b0 = dequant<w_type_id>(b_quant_0); FragB frag_b0 = dequant<w_type_id>(b_quant_0);
FragB frag_b1 = dequant<w_type_id>(b_quant_1); FragB frag_b1 = dequant<w_type_id>(b_quant_1);
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp(frag_b0, frag_zp[j], 0);
}
// Apply scale to frag_b0 // Apply scale to frag_b0
if constexpr (has_act_order) { if constexpr (has_act_order) {
...@@ -829,6 +999,11 @@ __device__ inline void MarlinMoESingle( ...@@ -829,6 +999,11 @@ __device__ inline void MarlinMoESingle(
} }
} }
// Apply zero-point to frag_b1
if constexpr (has_zp) {
sub_zp(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b1 // Apply scale to frag_b1
if constexpr (has_act_order) { if constexpr (has_act_order) {
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
...@@ -1062,9 +1237,6 @@ __device__ inline void MarlinMoESingle( ...@@ -1062,9 +1237,6 @@ __device__ inline void MarlinMoESingle(
// Start global fetch and register load pipelines. // Start global fetch and register load pipelines.
auto start_pipes = [&]() { auto start_pipes = [&]() {
// TODO re-enable after fixing this function
// fetch_sorted_ids_to_shared();
// __syncthreads();
#pragma unroll #pragma unroll
for (int i = 0; i < stages - 1; i++) { for (int i = 0; i < stages - 1; i++) {
...@@ -1075,6 +1247,12 @@ __device__ inline void MarlinMoESingle( ...@@ -1075,6 +1247,12 @@ __device__ inline void MarlinMoESingle(
} }
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
} }
if constexpr (has_zp && group_blocks == -1) {
if (i == 0) {
fetch_zp_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters); fetch_to_shared(i, i, i < slice_iters);
} }
...@@ -1083,6 +1261,7 @@ __device__ inline void MarlinMoESingle( ...@@ -1083,6 +1261,7 @@ __device__ inline void MarlinMoESingle(
init_same_group(0); init_same_group(0);
fetch_to_registers(0, 0); fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0); fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1); a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1);
}; };
...@@ -1102,6 +1281,7 @@ __device__ inline void MarlinMoESingle( ...@@ -1102,6 +1281,7 @@ __device__ inline void MarlinMoESingle(
for (int k = 0; k < b_sh_wr_iters; k++) { for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages); fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe); fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) { if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe, fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages); slice_iters >= stages);
...@@ -1236,7 +1416,9 @@ __device__ inline void MarlinMoESingle( ...@@ -1236,7 +1416,9 @@ __device__ inline void MarlinMoESingle(
} else { } else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} }
start_pipes(); start_pipes();
} }
} }
...@@ -1250,6 +1432,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id ...@@ -1250,6 +1432,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale // with a separate quantization scale
> >
...@@ -1261,6 +1444,8 @@ __global__ void MarlinMoE( ...@@ -1261,6 +1444,8 @@ __global__ void MarlinMoE(
const float* __restrict__ topk_weights, // float topk weights const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets, const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
...@@ -1309,29 +1494,29 @@ __global__ void MarlinMoE( ...@@ -1309,29 +1494,29 @@ __global__ void MarlinMoE(
if (max_block == 1) { if (max_block == 1) {
MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks, MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>( stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block); current_m_block);
} else if (max_block == 2) { } else if (max_block == 2) {
MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks, MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>( stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block); current_m_block);
} else if (max_block == 3) { } else if (max_block == 3) {
MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks, MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>( stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block); current_m_block);
} else { } else {
MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks, MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>( stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block); current_m_block);
...@@ -1347,6 +1532,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id ...@@ -1347,6 +1532,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale // with a separate quantization scale
> >
...@@ -1358,6 +1544,8 @@ __global__ void MarlinMoE( ...@@ -1358,6 +1544,8 @@ __global__ void MarlinMoE(
const float* __restrict__ topk_weights, // float topk weights const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets, const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
...@@ -1374,7 +1562,6 @@ __global__ void MarlinMoE( ...@@ -1374,7 +1562,6 @@ __global__ void MarlinMoE(
int current_m_block, // current m block to start kernel computation from int current_m_block, // current m block to start kernel computation from
int max_par, // maximum parallelism int max_par, // maximum parallelism
int cfg_max_m_blocks // upper bound on m blocks int cfg_max_m_blocks // upper bound on m blocks
) { ) {
// Marlin is not implemented yet for SM < 8.0 // Marlin is not implemented yet for SM < 8.0
assert(false); assert(false);
...@@ -1389,37 +1576,41 @@ __global__ void MarlinMoE( ...@@ -1389,37 +1576,41 @@ __global__ void MarlinMoE(
const int USER_THREADS = const int USER_THREADS =
256; // Note: This is only used with user-provided thread_k/n 256; // Note: This is only used with user-provided thread_k/n
const int STAGES = 4; // 4 pipeline stages fit into shared memory const int STAGES = 4; // 4 pipeline stages fit into shared memory
// const int SHARED_MEM =
// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
static constexpr int min_thread_n = 64; static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64; static constexpr int min_thread_k = 64;
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ #define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
GROUP_BLOCKS, NUM_THREADS) \ HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
num_threads == NUM_THREADS) { \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \ STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \ STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \ replicate_input, apply_weights, m_block, max_par, \
cfg_max_m_blocks); \ cfg_max_m_blocks); \
} }
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
} // namespace marlin_moe } // namespace marlin_moe
#include "marlin_moe_kernel_ku4.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = true;
if (false) {
}
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe
...@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku4b8( ...@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku4b8(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks) { int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) { if (false) {
} }
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
......
...@@ -11,10 +11,10 @@ bool call_marlin_moe_kernel_ku4b8( ...@@ -11,10 +11,10 @@ bool call_marlin_moe_kernel_ku4b8(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks); int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe } // namespace marlin_moe
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment