"README_ORIGIN.md" did not exist on "c2e8357f6c7dc01a4b21a6db456ef303e07e26cb"
Unverified Commit 57e50f8d authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformated codes

* style: reformated codes
parent b737368d
@@ -219,28 +219,30 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
    // weight = weight.copy(weight.device());
    dispatchF16(weight.dtype(), [&]<typename half_t>() {
        using ElementOutput = half_t;
        using ElementAccumulator = half_t;
        using ElementComputeEpilogue = half_t;
        using ElementInputA = half_t;
        using ElementInputB = half_t;

        using LayoutInputA = cutlass::layout::TensorNHWC;
        using LayoutInputB = cutlass::layout::TensorNHWC;
        using LayoutOutput = cutlass::layout::TensorNHWC;

        using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
        using FilterShape = cutlass::MatrixShape<3, 3>;
        using ThreadblockShape = cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, 64, FilterShape::kCount>;
        using WarpShape = cutlass::gemm::GemmShape<16, 64, FilterShape::kCount>;
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

        using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
            ElementInputA,
            LayoutInputA,
            ElementInputB,
            LayoutInputB,
            ElementOutput,
            LayoutOutput,
            ElementAccumulator,
            cutlass::arch::OpClassSimt,
            cutlass::arch::Sm80,
@@ -249,15 +251,14 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
            FilterShape,
            WarpShape,
            InstructionShape,
            cutlass::epilogue::thread::LinearCombination<ElementOutput,
                                                         128 / cutlass::sizeof_bits<ElementOutput>::value,
                                                         ElementOutput,
                                                         ElementComputeEpilogue>,
            cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<1,
                                                                                        ThreadBlockOutputShape::kN,
                                                                                        ThreadBlockOutputShape::kH,
                                                                                        ThreadBlockOutputShape::kW>,
            4,
            cutlass::arch::OpMultiplyAdd,
            cutlass::conv::IteratorAlgorithm::kFixedStrideDilation,
@@ -267,15 +268,14 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
        using DeviceKernel = typename cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

        cutlass::conv::Conv2dProblemSize problem_size(cutlass::Tensor4DCoord(N, H, W, C_),
                                                      cutlass::Tensor4DCoord(K, R, S, C__),
                                                      cutlass::Tensor4DCoord(1, 1, 1, 1),
                                                      cutlass::MatrixCoord(1, 1),
                                                      cutlass::MatrixCoord(1, 1),
                                                      cutlass::conv::Mode::kCrossCorrelation,
                                                      1,
                                                      C_ // groups
        );

        const int P = problem_size.P;
@@ -292,11 +292,17 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
        Tensor tmp_weight = Tensor::empty_like(weight);

        cutlass::TensorRef<ElementInputA, LayoutInputA> a_ref(
            input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(2), input.stride(1), input.stride(0)));
        cutlass::TensorRef<ElementInputB, LayoutInputB> b_ref(
            weight.data_ptr<ElementInputB>(), LayoutInputB(weight.stride(2), weight.stride(1), weight.stride(0)));
        cutlass::TensorRef<ElementOutput, LayoutOutput> c_ref(
            bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0, 0, 0));
        cutlass::TensorRef<ElementOutput, LayoutOutput> d_ref(
            out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(2), out.stride(1), out.stride(0)));
        cutlass::TensorRef<ElementOutput, LayoutOutput> tmpw_ref(
            tmp_weight.data_ptr<ElementOutput>(),
            LayoutOutput(tmp_weight.stride(2), tmp_weight.stride(1), tmp_weight.stride(0)));

        typename DeviceKernel::Arguments arguments{
            problem_size,
@@ -315,7 +321,6 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
        BufferCUDA workspace(workspace_size);
        auto stream = getCurrentCUDAStream();
        cutlass::Status status = implicit_gemm_op.can_implement(arguments);
        if (status != cutlass::Status::kSuccess) {
            throw std::runtime_error("cutlass cannot implement");
@@ -333,4 +338,4 @@ Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
    });

    return out;
}
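For orientation, the Conv2dProblemSize above describes a 3x3 depthwise convolution with padding 1, stride 1, and dilation 1, so the output extent P equals the input extent H. A minimal host-side sketch of that output-size arithmetic (illustrative only, not part of this commit; the helper name and example extents are hypothetical):

// Illustrative sketch: output-extent arithmetic implied by the Conv2dProblemSize
// above (3x3 filter, padding 1, stride 1, dilation 1).
#include <cstdio>

// Standard cross-correlation output size.
static int conv_out_extent(int in, int filter, int pad, int stride, int dilation) {
    return (in + 2 * pad - dilation * (filter - 1) - 1) / stride + 1;
}

int main() {
    const int H = 64, W = 64, R = 3, S = 3; // hypothetical example extents
    const int P = conv_out_extent(H, R, /*pad=*/1, /*stride=*/1, /*dilation=*/1);
    const int Q = conv_out_extent(W, S, 1, 1, 1);
    std::printf("P=%d Q=%d\n", P, Q); // prints P=64 Q=64: output matches input extent
    return 0;
}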
@@ -5,4 +5,4 @@
// Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias);
@@ -9,18 +9,16 @@
using spdlog::fmt_lib::format;

Tensor gemm_batched_fp16(Tensor a,  // FP16 row-major [(... batch ...), M, K]
                         Tensor b,  // FP16 col-major [(... batch ...), N, K]
                         Tensor out // FP32 row-major [(... batch ...), M, N]
) {
    const int M = a.shape[-2];
    const int K = a.shape[-1];
    const int N = a.shape[-2];
    const int batch = a.numel() / (M * K);

    using ElementInput = cutlass::half_t;
    using ElementOutput = float;
    using LayoutA = cutlass::layout::RowMajor;
@@ -28,18 +26,23 @@ Tensor gemm_batched_fp16(
    using LayoutO = cutlass::layout::RowMajor;

    using Gemm = cutlass::gemm::device::GemmBatched<
        ElementInput,
        LayoutA,
        ElementInput,
        LayoutB,
        ElementOutput,
        LayoutO,
        ElementOutput,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<32, 32, 64>,
        cutlass::gemm::GemmShape<32, 32, 64>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        cutlass::epilogue::thread::LinearCombination<ElementOutput,
                                                     128 / cutlass::sizeof_bits<ElementOutput>::value,
                                                     ElementOutput,
                                                     ElementOutput>,
        cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
        2>;

    auto sizeA = cutlass::MatrixCoord(M, K);
@@ -48,8 +51,8 @@ Tensor gemm_batched_fp16(
    if (!out.valid()) {
        auto outShape = TensorShape(a.shape.dataExtent);
        outShape[-1] = N;
        out = Tensor::empty(outShape, Tensor::FP32, a.device());
    }

    assert(K == b.shape[-1]);
@@ -62,28 +65,23 @@ Tensor gemm_batched_fp16(
    cutlass::gemm::GemmCoord problemSize(M, N, K);

    cutlass::TensorRef<ElementInput, LayoutA> refA(a.data_ptr<ElementInput>(), LayoutA(a.stride(-2)));
    cutlass::TensorRef<ElementInput, LayoutB> refB(b.data_ptr<ElementInput>(), LayoutB(b.stride(-2)));
    cutlass::TensorRef<ElementOutput, LayoutO> refO(out.data_ptr<ElementOutput>(), LayoutO(out.stride(-2)));

    typename Gemm::Arguments arguments{problemSize,
                                       refA,
                                       (int)a.stride(-3),
                                       refB,
                                       (int)b.stride(-3),
                                       refO,
                                       (int)out.stride(-3),
                                       refO,
                                       (int)out.stride(-3),
                                       {ElementOutput(1), ElementOutput(0)},
                                       batch};

    Gemm op;

    BufferCUDA workspace(Gemm::get_workspace_size(arguments));

    cutlass::Status status = op.can_implement(arguments);
@@ -102,4 +100,4 @@ Tensor gemm_batched_fp16(
    }

    return out;
}
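As a reading aid, the arguments above describe a strided batched GEMM in which a is row-major [batch, M, K], b is interpreted as column-major with leading dimension b.stride(-2) (i.e. stored as [batch, N, K]), and the FP32 output is out = a x b^T per batch with alpha = 1 and beta = 0. A plain reference sketch of that semantics, assuming fully contiguous tensors and float stand-ins for the FP16 inputs (illustrative only, not part of this commit):

// Reference semantics sketch (not the CUTLASS path): out[b] = a[b] * b[b]^T,
// with alpha = 1, beta = 0, and the batch strides M*K, N*K, M*N respectively.
#include <vector>

static void gemm_batched_reference(const std::vector<float> &a,   // [batch, M, K] row-major
                                   const std::vector<float> &b,   // [batch, N, K] row-major storage
                                   std::vector<float> &out,       // [batch, M, N] row-major
                                   int batch, int M, int N, int K) {
    for (int bi = 0; bi < batch; ++bi)
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float acc = 0.f; // FP32 accumulation, as in ElementOutput = float
                for (int k = 0; k < K; ++k)
                    acc += a[(bi * M + m) * K + k] * b[(bi * N + n) * K + k];
                out[(bi * M + m) * N + n] = acc;
            }
}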
@@ -3,8 +3,7 @@
#include "common.h"
#include "Tensor.h"

Tensor gemm_batched_fp16(Tensor a,  // FP16 row-major [(... batch ...), M, K]
                         Tensor b,  // FP16 col-major [(... batch ...), N, K]
                         Tensor out // FP32 row-major [(... batch ...), M, N]
);
@@ -14,10 +14,9 @@ using spdlog::fmt_lib::format;
Tensor gemm_f16(Tensor input,  // FP16
                Tensor weight, // FP16
                Tensor out,    // FP16
                Tensor bias,
                float alpha) {
    auto N = weight.size(0);
    auto K = input.size(-1);
    auto M = input.numel() / K;
@@ -26,30 +25,34 @@ Tensor gemm_f16(Tensor input, // FP16
    spdlog::debug("gemm_f16: M={} K={} N={}", M, K, N);

    dispatchF16(weight.dtype(), [&]<typename half_t>() {
        using ElementOutput = half_t;
        using ElementAccumulator = float;
        using ElementComputeEpilogue = half_t;
        using ElementInputA = half_t; // <- data type of elements in input matrix A
        using ElementInputB = half_t; // <- data type of elements in input matrix B
        using LayoutInputA = cutlass::layout::RowMajor;
        using LayoutInputB = cutlass::layout::ColumnMajor;
        using LayoutOutput = cutlass::layout::RowMajor;

        // #if CUDA_ARCH >= 800
        using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                                 cutlass::layout::RowMajor,
                                                 ElementInputB,
                                                 cutlass::layout::ColumnMajor,
                                                 ElementOutput,
                                                 cutlass::layout::RowMajor,
                                                 ElementAccumulator,
                                                 cutlass::arch::OpClassTensorOp,
                                                 cutlass::arch::Sm75>;
        // cutlass::gemm::GemmShape<128, 128, 64>,
        // cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
        // cutlass::epilogue::thread::LinearCombination<
        //     ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        //     ElementAccumulator, ElementComputeEpilogue>,
        // cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;

        auto input_size = cutlass::MatrixCoord(M, K);
        auto weight_size = cutlass::MatrixCoord(K, N);
        auto output_size = cutlass::MatrixCoord(M, N);
@@ -59,8 +62,8 @@ Tensor gemm_f16(Tensor input, // FP16
        if (!out.valid()) {
            auto out_shape = TensorShape(input.shape.dataExtent);
            out_shape[-1] = N;
            out = Tensor::empty(out_shape, input.scalar_type(), input.device());
        }

        // FIXME: check contiguous of input if dims >= 3
@@ -83,23 +86,22 @@ Tensor gemm_f16(Tensor input, // FP16
        // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;

        cutlass::gemm::GemmCoord problem_size(M, N, K);

        cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(input.data_ptr<ElementInputA>(),
                                                                  LayoutInputA(input.stride(-2)));
        cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(weight.data_ptr<ElementInputB>(),
                                                                   LayoutInputB::packed(weight_size));
        cutlass::TensorRef<ElementOutput, LayoutOutput> bias_ref(
            bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0));
        cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(out.data_ptr<ElementOutput>(),
                                                                LayoutOutput(out.stride(-2)));

        typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
                                           input_ref,    // <- reference to matrix A on device
                                           weight_ref,   // <- reference to matrix B on device
                                           bias_ref,     // <- reference to matrix C on device
                                           out_ref,      // <- reference to matrix D on device
                                           {ElementOutput(alpha), ElementOutput(bias.valid() ? 1.0f : 0.0f)},
                                           1};

        Gemm gemm_op;

        // Using the arguments, query for extra workspace required for matrix
@@ -127,9 +129,7 @@ Tensor gemm_f16(Tensor input, // FP16
        if (status != cutlass::Status::kSuccess) {
            throw std::runtime_error("cutlass cannot run");
        }
    });

    return out;
}
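For reference, the epilogue arguments above implement D = alpha * (input x weight^T) + beta * C, where beta is 1.0 when a bias tensor is supplied and 0.0 otherwise, and bias_ref uses LayoutOutput(0) so the single bias row broadcasts over all M output rows. A scalar sketch of that epilogue (illustrative only, not part of this commit):

// Illustrative scalar sketch of the LinearCombination epilogue used above:
// each output element is alpha * accumulator + beta * bias, with beta gated
// on whether a bias tensor was passed in.
static float gemm_f16_epilogue(float accumulator, float alpha, float bias_value, bool has_bias) {
    const float beta = has_bias ? 1.0f : 0.0f;
    return alpha * accumulator + beta * bias_value;
}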
@@ -3,10 +3,8 @@
#include "common.h"
#include "Tensor.h"

Tensor gemm_f16(Tensor input,  // FP16
                Tensor weight, // FP16
                Tensor out,    // FP16
                Tensor bias,
                float alpha);
@@ -11,10 +11,10 @@
using spdlog::fmt_lib::format;

Tensor gemm_w8a8_fp16(Tensor input,  // INT8
                      Tensor weight, // INT8
                      Tensor out,    // FP16
                      half alpha,
                      half beta // FP16
) {
    auto N = weight.size(0);
    auto K = input.size(-1);
@@ -23,57 +23,66 @@ Tensor gemm_w8a8_fp16(Tensor input, // INT8
    spdlog::debug("gemm_w8a8: M={} K={} N={}", M, K, N);

    using ElementOutput = cutlass::half_t;
    using ElementAccumulator = int32_t;
    using ElementComputeEpilogue = cutlass::half_t;
    using ElementInputA = int8_t; // <- data type of elements in input matrix A
    using ElementInputB = int8_t; // <- data type of elements in input matrix B
    using LayoutInputA = cutlass::layout::RowMajor;
    using LayoutInputB = cutlass::layout::ColumnMajor;
    using LayoutOutput = cutlass::layout::RowMajor;

    // #if CUDA_ARCH >= 800
    using Gemm = cutlass::gemm::device::Gemm<
        int8_t,
        cutlass::layout::RowMajor,
        int8_t,
        cutlass::layout::ColumnMajor,
        ElementOutput,
        cutlass::layout::RowMajor,
        ElementAccumulator,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<32, 64, 64>,
        cutlass::gemm::GemmShape<16, 8, 32>,
        cutlass::epilogue::thread::LinearCombination<ElementOutput,
                                                     128 / cutlass::sizeof_bits<ElementOutput>::value,
                                                     ElementAccumulator,
                                                     ElementComputeEpilogue>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        3>;
    // #elif CUDA_ARCH >= 750
    // using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
    //     cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
    //     ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
    // using Gemm = cutlass::gemm::device::Gemm<
    //     int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
    //     ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
    //     cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
    //     DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
    //     DefaultGemmCfg::InstructionShape,
    //     cutlass::epilogue::thread::LinearCombination<
    //         ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    //         ElementAccumulator, ElementComputeEpilogue>>;
    // #elif CUDA_ARCH >= 700
    // using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
    //     cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
    //     ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
    // using Gemm = cutlass::gemm::device::Gemm<
    //     int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
    //     ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
    //     cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
    //     DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
    //     DefaultGemmCfg::InstructionShape,
    //     cutlass::epilogue::thread::LinearCombination<
    //         ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
    // #else
    // #error "Unsupported cuda arch"
    // #endif

    auto input_size = cutlass::MatrixCoord(M, K);
    auto weight_size = cutlass::MatrixCoord(K, N);
    auto output_size = cutlass::MatrixCoord(M, N);
@@ -83,8 +92,8 @@ Tensor gemm_w8a8_fp16(Tensor input, // INT8
    if (!out.valid()) {
        auto out_shape = TensorShape(input.shape.dataExtent);
        out_shape[-1] = N;
        out = Tensor::empty(out_shape, Tensor::FP16, input.device());
    }

    // FIXME: check contiguous of input if dims >= 3
@@ -105,21 +114,20 @@ Tensor gemm_w8a8_fp16(Tensor input, // INT8
    // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;

    cutlass::gemm::GemmCoord problem_size(M, N, K);

    cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(input.data_ptr<ElementInputA>(),
                                                              LayoutInputA(input.stride(-2)));
    cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(weight.data_ptr<ElementInputB>(),
                                                               LayoutInputB::packed(weight_size));
    cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(out.data_ptr<ElementOutput>(),
                                                            LayoutOutput(out.stride(-2)));

    typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication
                                       input_ref,    // <- reference to matrix A on device
                                       weight_ref,   // <- reference to matrix B on device
                                       out_ref,      // <- reference to matrix C on device
                                       out_ref,      // <- reference to matrix D on device
                                       {ElementOutput(alpha), ElementOutput(beta)},
                                       1};

    Gemm gemm_op;

    // Using the arguments, query for extra workspace required for matrix
......
@@ -7,5 +7,4 @@ Tensor gemm_w8a8_fp16(Tensor input, // INT8
                      Tensor weight, // INT8
                      Tensor out,
                      half scale,
                      half bias);
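For reference, the W8A8 path above accumulates int8 x int8 products into an int32 accumulator, and the LinearCombination epilogue rescales them to FP16 as out = alpha * acc + beta * C, with the C operand bound to the same out_ref as D in the arguments above; the header exposes alpha and beta as half-precision scale and bias parameters. A scalar sketch with float stand-ins for half (illustrative only, not part of this commit):

// Illustrative scalar sketch of the W8A8 epilogue: int32 accumulator rescaled
// to a half-precision output (plain floats used here in place of half).
#include <cstdint>

static float w8a8_epilogue(int32_t accumulator, float scale, float bias, float c_value) {
    return scale * static_cast<float>(accumulator) + bias * c_value;
}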
@@ -6,175 +6,212 @@ void rms_norm(Tensor &out, // [..., hidden_size]
              Tensor &weight, // [hidden_size]
              float epsilon,
              bool use_quant) {
    int hidden_size = input.size(-1);
    int num_tokens = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
        if (use_quant) {
            vllm::rms_norm_kernel<scalar_t, int8_t, true><<<grid, block, 0, stream>>>(out.data_ptr<int8_t>(),
                                                                                      input.data_ptr<scalar_t>(),
                                                                                      weight.data_ptr<scalar_t>(),
                                                                                      epsilon,
                                                                                      num_tokens,
                                                                                      hidden_size);
        } else {
            vllm::rms_norm_kernel<scalar_t, scalar_t, false><<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                                                         input.data_ptr<scalar_t>(),
                                                                                         weight.data_ptr<scalar_t>(),
                                                                                         epsilon,
                                                                                         num_tokens,
                                                                                         hidden_size);
        }
    });
}

void layernorm_general(Tensor out, Tensor input, Tensor weight, Tensor bias, float epsilon) {
    int hidden_size = input.size(-1);
    int num_tokens = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 256));
    block.x = 32 * ((block.x + 31) / 32);
    size_t size_shmem = input.scalar_size() * hidden_size;

    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm", [&] {
        using T = typename packed_as<scalar_t, 2>::type;
        vllm::generalLayerNorm<T, half, true><<<grid, block, size_shmem, stream>>>(
            reinterpret_cast<T *>(input.data_ptr<scalar_t>()),
            weight.valid() ? reinterpret_cast<T *>(weight.data_ptr<scalar_t>()) : nullptr,
            bias.valid() ? reinterpret_cast<T *>(bias.data_ptr<scalar_t>()) : nullptr,
            reinterpret_cast<T *>(out.data_ptr<scalar_t>()),
            epsilon,
            num_tokens,
            hidden_size,
            nullptr,
            nullptr,
            nullptr,
            true);
    });
}

void rms_norm_general(Tensor &out,     // [..., hidden_size]
                      Tensor &input,   // [..., hidden_size]
                      Tensor &weight,  // [hidden_size]
                      Tensor &scaling, // [tokens] or [1]
                      float epsilon,
                      bool use_per_token_quant) {
    int hidden_size = input.size(-1);
    int num_tokens = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    block.x = 32 * ((block.x + 31) / 32);

    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm", [&] {
        using T = scalar_t;
        if (use_per_token_quant) {
            // per-token
            vllm::generalLayerNorm<T, half>
                <<<grid, block, 0, stream>>>(reinterpret_cast<T *>(input.data_ptr<scalar_t>()),
                                             reinterpret_cast<T *>(weight.data_ptr<scalar_t>()),
                                             nullptr,
                                             nullptr,
                                             epsilon,
                                             num_tokens,
                                             hidden_size,
                                             nullptr,
                                             scaling.data_ptr<half>(),
                                             out.data_ptr<int8_t>(),
                                             false);
            // input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale
            // normed_output_quant, use_shmem
            // out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(),
            // weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
        } else {
            // per-tensor
            vllm::generalLayerNorm<T, half>
                <<<grid, block, 0, stream>>>(reinterpret_cast<T *>(input.data_ptr<scalar_t>()),
                                             reinterpret_cast<T *>(weight.data_ptr<scalar_t>()),
                                             nullptr,
                                             nullptr,
                                             epsilon,
                                             num_tokens,
                                             hidden_size,
                                             scaling.data_ptr<half>(),
                                             nullptr,
                                             out.data_ptr<int8_t>(),
                                             false);
        }
    });
}

void rms_norm_general_fuse_sum(Tensor &out,       // [..., hidden_size]
                               Tensor &input,     // [..., hidden_size]
                               Tensor &weight,    // [hidden_size]
                               Tensor &input_sum, // [tokens] or [1]
                               Tensor &scaling,   // [tokens] or [1]
                               float epsilon,
                               bool use_per_token_quant) {
    int hidden_size = input.size(-1);
    int num_tokens = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    block.x = 32 * ((block.x + 31) / 32);

    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm_fuse_sum", [&] {
        using T = scalar_t;
        if (use_per_token_quant) {
            // per-token
            vllm::generalLayerNorm_fuse_sum<T, half>
                <<<grid, block, 0, stream>>>(reinterpret_cast<T *>(input.data_ptr<scalar_t>()),
                                             reinterpret_cast<T *>(weight.data_ptr<scalar_t>()),
                                             nullptr,
                                             nullptr,
                                             epsilon,
                                             num_tokens,
                                             hidden_size,
                                             input_sum.data_ptr<half>(),
                                             nullptr,
                                             scaling.data_ptr<half>(),
                                             out.data_ptr<int8_t>(),
                                             false);
            // input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale
            // normed_output_quant, use_shmem
            // out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(),
            // weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
        } else {
            // per-tensor
            // Rasing error here
            // Not implemented per-tensor input_sum
            assert(false);
            vllm::generalLayerNorm_fuse_sum<T, half>
                <<<grid, block, 0, stream>>>(reinterpret_cast<T *>(input.data_ptr<scalar_t>()),
                                             reinterpret_cast<T *>(weight.data_ptr<scalar_t>()),
                                             nullptr,
                                             nullptr,
                                             epsilon,
                                             num_tokens,
                                             hidden_size,
                                             nullptr,
                                             scaling.data_ptr<half>(),
                                             nullptr,
                                             out.data_ptr<int8_t>(),
                                             false);
        }
    });
}

void invoke_dequant_add_residual_rms_norm_quant(Tensor &out,      // [..., hidden_size]
                                                Tensor &input,    // [..., hidden_size]
                                                Tensor &residual, // [..., hidden_size]
                                                Tensor &gamma,    // [hidden_size]
                                                half scale,
                                                float epsilon) {
    int hidden_size = input.size(-1);
    int num_tokens = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel", [&] {
        vllm::dequant_add_residual_rms_norm_quant_kernel<scalar_t, half, false>
            <<<grid, block, 0, stream>>>(input.data_ptr<int32_t>(),
                                         residual.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         gamma.data_ptr<scalar_t>(),
                                         epsilon,
                                         scale,
                                         num_tokens,
                                         hidden_size);
    });
}

void invoke_dequant_add_residual_rms_norm_quant(Tensor &out,      // [..., hidden_size]
                                                Tensor &input,    // [..., hidden_size]
                                                Tensor &residual, // [..., hidden_size]
                                                Tensor &gamma,    // [hidden_size]
                                                Tensor &scale,    // [num_tokens]
                                                float epsilon) {
    int hidden_size = input.size(-1);
    int num_tokens = input.numel() / hidden_size;
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    const cudaStream_t stream = getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel", [&] {
        vllm::dequant_add_residual_rms_norm_quant_kernel<scalar_t, half *, true>
            <<<grid, block, 0, stream>>>(input.data_ptr<int32_t>(),
                                         residual.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         gamma.data_ptr<scalar_t>(),
                                         epsilon,
                                         scale.data_ptr<half>(),
                                         num_tokens,
                                         hidden_size);
    });
}
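The per-token quantization performed by the generalLayerNorm kernels launched above reduces each row to its absolute maximum, scales by 127 / amax when writing int8, and stores the inverse scale amax / 127 for later dequantization. A host-side sketch of that row quantization, assuming a plain float row (illustrative only, not part of this commit):

// Illustrative host-side sketch of per-token (per-row) int8 quantization as
// done by the kernels above: scale by 127 / amax, store amax / 127.
#include <cmath>
#include <cstdint>
#include <vector>

static void quantize_row(const std::vector<float> &row, std::vector<int8_t> &q, float &scale_out) {
    float amax = 1e-6f; // same floor as the kernel's initial amax
    for (float v : row)
        amax = std::fmax(amax, std::fabs(v));
    const float s = 127.f / amax; // dynamic_per_token_scale in the kernel
    q.resize(row.size());
    for (size_t i = 0; i < row.size(); ++i)
        q[i] = static_cast<int8_t>(std::lrintf(row[i] * s));
    scale_out = amax / 127.f; // what the kernel stores in scale_orig_quant_per_token
}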
@@ -7,36 +7,36 @@
void rms_norm(Tensor &out,    // [num_tokens, hidden_size]
              Tensor &input,  // [num_tokens, hidden_size]
              Tensor &weight, // [hidden_size]
              float epsilon,
              bool use_quant);

void layernorm_general(Tensor out, Tensor input, Tensor weight, Tensor bias, float epsilon);

void rms_norm_general(Tensor &out,     // [..., hidden_size]
                      Tensor &input,   // [..., hidden_size]
                      Tensor &weight,  // [hidden_size]
                      Tensor &scaling, // [tokens] or [1]
                      float epsilon,
                      bool use_per_token_quant);

void rms_norm_general_fuse_sum(Tensor &out,       // [..., hidden_size]
                               Tensor &input,     // [..., hidden_size]
                               Tensor &weight,    // [hidden_size]
                               Tensor &input_sum, // [tokens] or [1]
                               Tensor &scaling,   // [tokens] or [1]
                               float epsilon,
                               bool use_per_token_quant);

void invoke_dequant_add_residual_rms_norm_quant(Tensor &out,      // [..., hidden_size]
                                                Tensor &input,    // [..., hidden_size]
                                                Tensor &residual, // [..., hidden_size]
                                                Tensor &gamma,    // [hidden_size]
                                                half scale,
                                                float epsilon);

void invoke_dequant_add_residual_rms_norm_quant(Tensor &out,      // [..., hidden_size]
                                                Tensor &input,    // [..., hidden_size]
                                                Tensor &residual, // [..., hidden_size]
                                                Tensor &gamma,    // [hidden_size]
                                                Tensor &scale,    // [num_tokens]
                                                float epsilon);
@@ -5,13 +5,12 @@
#include "utils.cuh"
#include "reduction_utils.cuh"

namespace vllm {

// from TRTLLM
template<typename Tf, typename T>
__inline__ __device__ Tf
compute_layernorm(Tf val, float s_mean, float s_variance, const T *gamma, const T *beta, int i) {
    Tf ret = (val - s_mean) * s_variance;
    if (gamma != nullptr) {
        ret = ret * cuda_cast<Tf>(gamma[i]);
@@ -44,353 +43,320 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance, ...
 * amax per row. A final pass scales to int8 accordingly, and writes output to
 * normed_output_quant.
 */
template<typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(const T *input,
                                 const T *gamma,
                                 const T *beta,
                                 T *normed_output,
                                 const float eps,
                                 int tokens,
                                 int hidden_dim,
                                 const scale_type *scale_orig_quant_per_tensor,
                                 scale_type *scale_orig_quant_per_token,
                                 int8_t *normed_output_quant,
                                 bool use_shmem) {
    constexpr auto num_elems_T = num_elems<T>::value;
    using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
    using float_packed_t = typename packed_as<float, num_elems_T>::type;
    using T_scalar = typename packed_as<T, 1>::type;

    extern __shared__ __align__(sizeof(float)) char _shmem[];
    T *shmem = reinterpret_cast<T *>(_shmem);
    __shared__ float s_mean;
    __shared__ float s_variance;

    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    float mean = 0.0f;
    float variance = 0.0f;
    float local_sum = 0.0f;
    float local_var_sum = 0.0f;

    const int n_elems = hidden_dim / num_elems_T;
    for (int i = tidx; i < n_elems; i += blockDim.x) {
        const T val = input[bidx * n_elems + i];
        if (use_shmem) {
            shmem[i] = val;
        }

        const float_packed_t val_f = cuda_cast<float_packed_t>(val);
        local_sum += cuda_sum<float>(val_f);
        if (USE_DIFF_OF_SQUARES) {
            local_var_sum += cuda_sum<float>(val_f * val_f);
        }
    }

    if (USE_DIFF_OF_SQUARES) {
        float packed[2] = {local_sum, local_var_sum};
        blockReduceSumV2<float, 2>(packed);
        mean = packed[0];
        variance = packed[1];
    } else {
        mean = blockReduceSum(local_sum);
    }

    if (threadIdx.x == 0) {
        mean = mean / hidden_dim;
        s_mean = mean;
        if (USE_DIFF_OF_SQUARES) {
            variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
            s_variance = rsqrtf(variance + eps);
        }
    }
    __syncthreads();

    if (!USE_DIFF_OF_SQUARES) {
        for (int i = tidx; i < n_elems; i += blockDim.x) {
            const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
            float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
            local_var_sum += cuda_sum<float>(diff * diff);
        }
        variance = blockReduceSum(local_var_sum);

        if (threadIdx.x == 0) {
            s_variance = rsqrtf(variance / hidden_dim + eps);
        }
        __syncthreads();
    }

    const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
    const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
    const float_packed_t scale_orig_quant =
        cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
    T_scalar amax = 1e-6f;

    for (int i = tidx; i < n_elems; i += blockDim.x) {
        const int index = bidx * n_elems + i;
        const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
        const T val = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i));

        if (with_per_token_scaling) {
            amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
            if (use_shmem) {
                shmem[i] = val;
            }
        } else if (with_per_tensor_scaling) {
            reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
                cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
        } else {
            normed_output[index] = val;
        }
    }

    if (with_per_token_scaling) {
        float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
        const float dynamic_per_token_scale = 127.f / abs_max_f;
        for (int i = tidx; i < n_elems; i += blockDim.x) {
            const int index = bidx * n_elems + i;
            float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
            if (!use_shmem) {
                val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i);
            }

            reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
                cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
        }

        if (tidx == 0) {
            scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
        }
    }
}
template<typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false> __global__ void generalLayerNorm_fuse_sum(const T *input,
__global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, const T *gamma,
int tokens, int hidden_dim, scale_type* input_sum, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token, const T *beta,
int8_t* normed_output_quant, bool use_shmem) T *normed_output,
{ const float eps,
int tokens,
int hidden_dim,
scale_type *input_sum,
const scale_type *scale_orig_quant_per_tensor,
scale_type *scale_orig_quant_per_token,
int8_t *normed_output_quant,
bool use_shmem) {
constexpr auto num_elems_T = num_elems<T>::value; constexpr auto num_elems_T = num_elems<T>::value;
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type; using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type; using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type; using T_scalar = typename packed_as<T, 1>::type;
extern __shared__ __align__(sizeof(float)) char _shmem[]; extern __shared__ __align__(sizeof(float)) char _shmem[];
T* shmem = reinterpret_cast<T*>(_shmem); T *shmem = reinterpret_cast<T *>(_shmem);
__shared__ float s_mean; __shared__ float s_mean;
__shared__ float s_variance; __shared__ float s_variance;
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
const int bidx = blockIdx.x; const int bidx = blockIdx.x;
float mean = 0.0f; float mean = 0.0f;
float variance = 0.0f; float variance = 0.0f;
float local_sum = 0.0f; float local_sum = 0.0f;
float local_var_sum = 0.0f; float local_var_sum = 0.0f;
const int n_elems = hidden_dim / num_elems_T; const int n_elems = hidden_dim / num_elems_T;
for (int i = tidx; i < n_elems; i += blockDim.x) for (int i = tidx; i < n_elems; i += blockDim.x) {
{
const T val = input[bidx * n_elems + i]; const T val = input[bidx * n_elems + i];
if (use_shmem) if (use_shmem) {
{
shmem[i] = val; shmem[i] = val;
} }
const float_packed_t val_f = cuda_cast<float_packed_t>(val); const float_packed_t val_f = cuda_cast<float_packed_t>(val);
local_sum += cuda_sum<float>(val_f); local_sum += cuda_sum<float>(val_f);
if (USE_DIFF_OF_SQUARES) if (USE_DIFF_OF_SQUARES) {
{
local_var_sum += cuda_sum<float>(val_f * val_f); local_var_sum += cuda_sum<float>(val_f * val_f);
} }
} }
if (USE_DIFF_OF_SQUARES) if (USE_DIFF_OF_SQUARES) {
{
float packed[2] = {local_sum, local_var_sum}; float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed); blockReduceSumV2<float, 2>(packed);
mean = packed[0]; mean = packed[0];
variance = packed[1]; variance = packed[1];
} } else {
else
{
mean = blockReduceSum(local_sum); mean = blockReduceSum(local_sum);
} }
if (threadIdx.x == 0) if (threadIdx.x == 0) {
{ mean = mean / hidden_dim;
mean = mean / hidden_dim;
s_mean = mean; s_mean = mean;
if (USE_DIFF_OF_SQUARES) if (USE_DIFF_OF_SQUARES) {
{ variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps); s_variance = rsqrtf(variance + eps);
} }
} }
__syncthreads(); __syncthreads();
if (!USE_DIFF_OF_SQUARES) if (!USE_DIFF_OF_SQUARES) {
{ for (int i = tidx; i < n_elems; i += blockDim.x) {
for (int i = tidx; i < n_elems; i += blockDim.x) const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
{
const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean; float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
local_var_sum += cuda_sum<float>(diff * diff); local_var_sum += cuda_sum<float>(diff * diff);
} }
variance = blockReduceSum(local_var_sum); variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0) if (threadIdx.x == 0) {
{
s_variance = rsqrtf(variance / hidden_dim + eps); s_variance = rsqrtf(variance / hidden_dim + eps);
} }
__syncthreads(); __syncthreads();
} }
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr; const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant const float_packed_t scale_orig_quant =
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f); cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
T_scalar amax = 1e-6f; T_scalar amax = 1e-6f;
T_scalar sum = 0.0f; T_scalar sum = 0.0f;
for (int i = tidx; i < n_elems; i += blockDim.x) for (int i = tidx; i < n_elems; i += blockDim.x) {
{ const int index = bidx * n_elems + i;
const int index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]); const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
const T val = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i)); const T val = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i));
if (with_per_token_scaling) if (with_per_token_scaling) {
{
amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax); amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
sum += cuda_sum<float>(val); sum += cuda_sum<float>(val);
if (use_shmem) if (use_shmem) {
{
shmem[i] = val; shmem[i] = val;
} }
} } else if (with_per_tensor_scaling) {
else if (with_per_tensor_scaling) reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
{ cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index] } else {
= cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
}
else
{
normed_output[index] = val; normed_output[index] = val;
} }
} }
if (with_per_token_scaling) if (with_per_token_scaling) {
{ float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax)); float sum_f = blockAllReduceSum(cuda_cast<float>(sum));
float sum_f = blockAllReduceSum(cuda_cast<float>(sum));
const float dynamic_per_token_scale = 127.f / abs_max_f; const float dynamic_per_token_scale = 127.f / abs_max_f;
for (int i = tidx; i < n_elems; i += blockDim.x) for (int i = tidx; i < n_elems; i += blockDim.x) {
{ const int index = bidx * n_elems + i;
const int index = bidx * n_elems + i;
float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]); float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
if (!use_shmem) if (!use_shmem) {
{
val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i); val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i);
} }
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index] reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
= cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale)); cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
} }
if (tidx == 0) if (tidx == 0) {
{
scale_orig_quant_per_token[bidx] = abs_max_f / 127.f; scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
input_sum[bidx] = sum_f; input_sum[bidx] = sum_f;
} }
} }
} }
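For context on the USE_DIFF_OF_SQUARES branch above: it folds mean and variance into a single pass by accumulating both Σx and Σx², then applies Var[x] = E[x²] − E[x]². A minimal scalar sketch of the same identity (names here are illustrative, not from this file):

// One-pass mean/variance using Var[x] = E[x^2] - E[x]^2 (illustrative sketch).
static void mean_var_one_pass(const float *x, int n, float &mean, float &var) {
    float sum = 0.f, sum_sq = 0.f;
    for (int i = 0; i < n; ++i) {
        sum    += x[i];
        sum_sq += x[i] * x[i];
    }
    mean = sum / n;
    var  = sum_sq / n - mean * mean; // can lose precision when mean*mean is close to E[x^2]
}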
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename out_type, bool use_quant> template<typename scalar_t, typename out_type, bool use_quant>
__global__ void __global__ void rms_norm_kernel(out_type *__restrict__ out, // [..., hidden_size]
rms_norm_kernel(out_type *__restrict__ out, // [..., hidden_size] const scalar_t *__restrict__ input, // [..., hidden_size]
const scalar_t *__restrict__ input, // [..., hidden_size] const scalar_t *__restrict__ weight, // [hidden_size]
const scalar_t *__restrict__ weight, // [hidden_size] const float epsilon,
const float epsilon, const int num_tokens, const int num_tokens,
const int hidden_size) { const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx]; const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x; variance += x * x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
if constexpr (use_quant) {
out[blockIdx.x * hidden_size + idx] = float_to_int8_rn(
((float)(x * s_variance)) * (float)(weight[idx]));
} else {
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} variance = blockReduceSum<float>(variance);
} if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
if constexpr (use_quant) {
out[blockIdx.x * hidden_size + idx] = float_to_int8_rn(((float)(x * s_variance)) * (float)(weight[idx]));
} else {
out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
}
}
}
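rms_norm_kernel above normalizes each token by the root mean square of its elements and rescales by weight; with use_quant it additionally rounds the result to int8. A scalar reference for the unquantized path (a sketch under an assumed row-major [num_tokens, hidden_size] layout):

#include <cmath>

// Reference for one token: y[i] = x[i] * rsqrt(mean(x^2) + eps) * weight[i].
static void rms_norm_reference(const float *x, const float *weight, float *y,
                               int hidden_size, float eps) {
    float sum_sq = 0.f;
    for (int i = 0; i < hidden_size; ++i)
        sum_sq += x[i] * x[i];
    const float inv_rms = 1.f / std::sqrt(sum_sq / hidden_size + eps);
    for (int i = 0; i < hidden_size; ++i)
        y[i] = x[i] * inv_rms * weight[i];
}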
template<typename T, typename scale_type, bool use_per_token_dequant>
__global__ void dequant_add_residual_rms_norm_quant_kernel(const int32_t *__restrict__ input,
T *__restrict__ residual,
int8_t *__restrict__ output,
const T *__restrict__ gamma,
const float layernorm_eps,
const scale_type scale,
int num_tokens,
int hidden_size) {
// layernorm module in the T5 style: no bias and no subtraction of the mean.

const int tid = threadIdx.x;
__shared__ float s_variance;
float variance = 0.0f;
template <typename T, typename scale_type, bool use_per_token_dequant> float local_var_sum = 0.0f;
__global__ void dequant_add_residual_rms_norm_quant_kernel( for (int i = tid; i < hidden_size; i += blockDim.x) {
const int32_t *__restrict__ input, T *__restrict__ residual, float diff = 0.0f;
int8_t *__restrict__ output, const T *__restrict__ gamma, if constexpr (use_per_token_dequant) {
const float layernorm_eps, const scale_type scale, int num_tokens, int hidden_size) { diff = ((((float)input[blockIdx.x * hidden_size + i]) * __half2float(scale[blockIdx.x])) +
// layernorm module in the T5 style: no bias and no subtraction of the mean. (float)residual[blockIdx.x * hidden_size + i]);
const int tid = threadIdx.x; } else {
diff = ((((float)input[blockIdx.x * hidden_size + i]) * __half2float(scale)) +
(float)residual[blockIdx.x * hidden_size + i]);
}
residual[blockIdx.x * hidden_size + i] = (T)diff;
local_var_sum += diff * diff;
}
variance = blockReduceSum(local_var_sum);
__shared__ float s_variance; if (threadIdx.x == 0) {
float variance = 0.0f; s_variance = rsqrtf(variance / (float)hidden_size + layernorm_eps);
}
__syncthreads();
float local_var_sum = 0.0f; for (int i = tid; i < hidden_size; i += blockDim.x) {
for (int i = tid; i < hidden_size; i += blockDim.x) { output[blockIdx.x * hidden_size + i] =
float diff = 0.0f; float_to_int8_rn((((float)(residual[blockIdx.x * hidden_size + i])) * s_variance) * (float)(gamma[i]));
if constexpr (use_per_token_dequant) {
diff = ((((float)input[blockIdx.x * hidden_size + i]) * __half2float(scale[blockIdx.x])) +
(float)residual[blockIdx.x * hidden_size + i]);
} else {
diff = ((((float)input[blockIdx.x * hidden_size + i]) * __half2float(scale)) +
(float)residual[blockIdx.x * hidden_size + i]);
} }
residual[blockIdx.x * hidden_size + i] = (T)diff;
local_var_sum += diff * diff;
}
variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / (float)hidden_size + layernorm_eps);
}
__syncthreads();
for (int i = tid; i < hidden_size; i += blockDim.x) {
output[blockIdx.x * hidden_size + i] = float_to_int8_rn(
(((float)(residual[blockIdx.x * hidden_size + i])) * s_variance) *
(float)(gamma[i]));
}
} }
} // namespace vllm } // namespace vllm
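The layernorm/RMSNorm kernels in this namespace are launched with one block per token and at most 1024 threads striding over hidden_size. A hypothetical host-side wrapper showing that launch shape (the wrapper is not part of this commit and assumes the kernel is visible in the same translation unit):

#include <algorithm>
#include <cuda_fp16.h>

void launch_rms_norm_fp16(half *out, const half *input, const half *weight,
                          int num_tokens, int hidden_size, float eps,
                          cudaStream_t stream) {
    const dim3 grid(num_tokens);                    // one block per token
    const dim3 block(std::min(hidden_size, 1024));  // threads stride over hidden_size
    vllm::rms_norm_kernel<half, half, /*use_quant=*/false>
        <<<grid, block, 0, stream>>>(out, input, weight, eps, num_tokens, hidden_size);
}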
...@@ -11,7 +11,7 @@ Tensor add(Tensor a, Tensor b) { ...@@ -11,7 +11,7 @@ Tensor add(Tensor a, Tensor b) {
assert(b.is_contiguous()); assert(b.is_contiguous());
int threadsPerBlock = 1024; int threadsPerBlock = 1024;
int blocksPerGrid = (a.numel() + threadsPerBlock - 1) / threadsPerBlock; int blocksPerGrid = (a.numel() + threadsPerBlock - 1) / threadsPerBlock;
auto stream = getCurrentCUDAStream(); auto stream = getCurrentCUDAStream();
...@@ -44,14 +44,23 @@ void mul_add(Tensor x, Tensor scale, Tensor bias) { ...@@ -44,14 +44,23 @@ void mul_add(Tensor x, Tensor scale, Tensor bias) {
assert(bias.numel() % unroll == 0); assert(bias.numel() % unroll == 0);
int threadsPerBlock = 1024; int threadsPerBlock = 1024;
int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll); int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);
auto stream = getCurrentCUDAStream(); auto stream = getCurrentCUDAStream();
dispatch(x.scalar_type(), [&]<typename scalar_t>() { dispatch(x.scalar_type(), [&]<typename scalar_t>() {
if (scale.valid()) { if (scale.valid()) {
mul_add_kernel<scalar_t, unroll, false><<<blocksPerGrid, threadsPerBlock, 0, stream>>>( mul_add_kernel<scalar_t, unroll, false>
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), 0, x.numel(), scale.numel(), bias.numel(), 0, 0, 0); <<<blocksPerGrid, threadsPerBlock, 0, stream>>>(x.data_ptr<scalar_t>(),
scale.data_ptr<scalar_t>(),
bias.data_ptr<scalar_t>(),
0,
x.numel(),
scale.numel(),
bias.numel(),
0,
0,
0);
} else { } else {
mul_add_kernel<scalar_t, unroll, true><<<blocksPerGrid, threadsPerBlock, 0, stream>>>( mul_add_kernel<scalar_t, unroll, true><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(), 0, x.numel(), 1, bias.numel(), 0, 0, 0); x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(), 0, x.numel(), 1, bias.numel(), 0, 0, 0);
...@@ -65,7 +74,7 @@ void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, ...@@ -65,7 +74,7 @@ void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift,
assert(!batch_scale || scale.shape[0] == batch_size); assert(!batch_scale || scale.shape[0] == batch_size);
assert(!batch_bias || bias.shape[0] == batch_size); assert(!batch_bias || bias.shape[0] == batch_size);
const int numel = x.numel() / batch_size; const int numel = x.numel() / batch_size;
const int numel_scale = scale.valid() ? (scale.numel() / (batch_scale ? batch_size : 1)) : 1; const int numel_scale = scale.valid() ? (scale.numel() / (batch_scale ? batch_size : 1)) : 1;
const int numel_bias = bias.numel() / (batch_bias ? batch_size : 1); const int numel_bias = bias.numel() / (batch_bias ? batch_size : 1);
...@@ -91,17 +100,29 @@ void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, ...@@ -91,17 +100,29 @@ void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift,
dispatch(x.scalar_type(), [&]<typename scalar_t>() { dispatch(x.scalar_type(), [&]<typename scalar_t>() {
if (scale.valid()) { if (scale.valid()) {
mul_add_kernel<scalar_t, unroll, false><<<grid, threadsPerBlock, 0, stream>>>( mul_add_kernel<scalar_t, unroll, false>
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), <<<grid, threadsPerBlock, 0, stream>>>(x.data_ptr<scalar_t>(),
(scalar_t)scale_shift, scale.data_ptr<scalar_t>(),
numel, numel_scale, numel_bias, bias.data_ptr<scalar_t>(),
x.stride(0), batch_scale ? scale.stride(0) : 0, batch_bias ? bias.stride(0) : 0); (scalar_t)scale_shift,
numel,
numel_scale,
numel_bias,
x.stride(0),
batch_scale ? scale.stride(0) : 0,
batch_bias ? bias.stride(0) : 0);
} else { } else {
mul_add_kernel<scalar_t, unroll, true><<<grid, threadsPerBlock, 0, stream>>>( mul_add_kernel<scalar_t, unroll, true>
x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(), <<<grid, threadsPerBlock, 0, stream>>>(x.data_ptr<scalar_t>(),
(scalar_t)scale_shift, nullptr,
numel, 1, numel_bias, bias.data_ptr<scalar_t>(),
x.stride(0), 0, batch_bias ? bias.stride(0) : 0); (scalar_t)scale_shift,
numel,
1,
numel_bias,
x.stride(0),
0,
batch_bias ? bias.stride(0) : 0);
} }
}); });
} }
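The blocksPerGrid expressions above are ceiling division over threadsPerBlock * unroll elements per block; spelled out as a helper (a sketch with a hypothetical name, separate from the ceilDiv helper used later in this file):

// Ceiling division: smallest number of blocks covering `total` elements
// when each block handles `per_block` of them.
template<typename T>
constexpr T ceil_div(T total, T per_block) {
    return (total + per_block - 1) / per_block; // e.g. ceil_div(1025, 1024) == 2
}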
...@@ -134,8 +155,7 @@ Tensor argmax_sample(Tensor logits) { ...@@ -134,8 +155,7 @@ Tensor argmax_sample(Tensor logits) {
dispatch(logits.scalar_type(), [&]<typename scalar_t>() { dispatch(logits.scalar_type(), [&]<typename scalar_t>() {
argmax_sample_kernel<<<logits.shape[0], std::min(logits.shape[1], 1024), 0, stream>>>( argmax_sample_kernel<<<logits.shape[0], std::min(logits.shape[1], 1024), 0, stream>>>(
logits.data_ptr<scalar_t>(), out.data_ptr<int32_t>(), logits.shape[1] logits.data_ptr<scalar_t>(), out.data_ptr<int32_t>(), logits.shape[1]);
);
}); });
return out; return out;
...@@ -155,20 +175,17 @@ void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v) { ...@@ -155,20 +175,17 @@ void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v) {
assert(dim_k == dim_v); assert(dim_k == dim_v);
assert(dim_q + dim_k + dim_v == qkv.shape[-1]); assert(dim_q + dim_k + dim_v == qkv.shape[-1]);
int num_tokens = qkv.numel() / qkv.shape[-1]; int num_tokens = qkv.numel() / qkv.shape[-1];
dispatch(qkv.scalar_type(), [&]<typename scalar_t>() { dispatch(qkv.scalar_type(), [&]<typename scalar_t>() {
splitqkv_kernel<<<num_tokens, std::min(qkv.shape[-1], 1024), 0, stream>>>( splitqkv_kernel<<<num_tokens, std::min(qkv.shape[-1], 1024), 0, stream>>>(qkv.data_ptr<scalar_t>(),
qkv.data_ptr<scalar_t>(), q.data_ptr<scalar_t>(),
q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(),
v.data_ptr<scalar_t>(), dim_q,
dim_q, dim_k);
dim_k
);
}); });
} }
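splitqkv above copies each token's fused [dim_q + dim_k + dim_v] row into three separate tensors; a scalar sketch of the same index mapping (shapes and names assumed for illustration):

// For each token, split the fused qkv row into its q, k, and v segments.
static void splitqkv_reference(const float *qkv, float *q, float *k, float *v,
                               int num_tokens, int dim_q, int dim_k, int dim_v) {
    const int dim_qkv = dim_q + dim_k + dim_v;
    for (int t = 0; t < num_tokens; ++t) {
        const float *row = qkv + t * dim_qkv;
        for (int i = 0; i < dim_q; ++i) q[t * dim_q + i] = row[i];
        for (int i = 0; i < dim_k; ++i) k[t * dim_k + i] = row[dim_q + i];
        for (int i = 0; i < dim_v; ++i) v[t * dim_v + i] = row[dim_q + dim_k + i];
    }
}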
template<size_t N> template<size_t N>
...@@ -176,7 +193,7 @@ std::array<Tensor, N> split_mod(Tensor input) { ...@@ -176,7 +193,7 @@ std::array<Tensor, N> split_mod(Tensor input) {
assert(input.shape[-1] % N == 0); assert(input.shape[-1] % N == 0);
int threadsPerBlock = 1024; int threadsPerBlock = 1024;
int blocksPerGrid = (input.numel() + threadsPerBlock - 1) / threadsPerBlock; int blocksPerGrid = (input.numel() + threadsPerBlock - 1) / threadsPerBlock;
auto stream = getCurrentCUDAStream(); auto stream = getCurrentCUDAStream();
...@@ -194,8 +211,7 @@ std::array<Tensor, N> split_mod(Tensor input) { ...@@ -194,8 +211,7 @@ std::array<Tensor, N> split_mod(Tensor input) {
outPtr[k] = out[k].template data_ptr<scalar_t>(); outPtr[k] = out[k].template data_ptr<scalar_t>();
} }
split_mod_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>( split_mod_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
input.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), outPtr, input.numel());
outPtr, input.numel());
}); });
return out; return out;
...@@ -209,7 +225,7 @@ Tensor quant_static(Tensor x, float scale) { ...@@ -209,7 +225,7 @@ Tensor quant_static(Tensor x, float scale) {
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0); assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
int threadsPerBlock = 1024; int threadsPerBlock = 1024;
int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll); int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);
auto stream = getCurrentCUDAStream(); auto stream = getCurrentCUDAStream();
...@@ -228,9 +244,8 @@ Tensor quant_static_fuse_gelu(Tensor x, float scale) { ...@@ -228,9 +244,8 @@ Tensor quant_static_fuse_gelu(Tensor x, float scale) {
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0); assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
int threadsPerBlock = 1024; int threadsPerBlock = 1024;
int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll); int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);
auto stream = getCurrentCUDAStream(); auto stream = getCurrentCUDAStream();
...@@ -258,7 +273,7 @@ void cast(Tensor input, Tensor output) { ...@@ -258,7 +273,7 @@ void cast(Tensor input, Tensor output) {
constexpr int unroll = 16 / std::max(sizeof(input_t), sizeof(output_t)); constexpr int unroll = 16 / std::max(sizeof(input_t), sizeof(output_t));
int threadsPerBlock = 1024; int threadsPerBlock = 1024;
int blocksPerGrid = (int)ceilDiv<int64_t>(input.numel(), threadsPerBlock * unroll); int blocksPerGrid = (int)ceilDiv<int64_t>(input.numel(), threadsPerBlock * unroll);
cast_kernel<input_t, output_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>( cast_kernel<input_t, output_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
input.data_ptr<input_t>(), output.data_ptr<output_t>(), input.numel()); input.data_ptr<input_t>(), output.data_ptr<output_t>(), input.numel());
...@@ -271,17 +286,16 @@ void cast(Tensor input, Tensor output) { ...@@ -271,17 +286,16 @@ void cast(Tensor input, Tensor output) {
Tensor topk(Tensor x, int k) { Tensor topk(Tensor x, int k) {
constexpr int MAXK = 64 + 4; constexpr int MAXK = 64 + 4;
const int N = x.shape[-1]; const int N = x.shape[-1];
const int batch = x.numel() / N; const int batch = x.numel() / N;
assert(k <= N); assert(k <= N);
assert(k <= MAXK); assert(k <= MAXK);
auto outShape = TensorShape(x.shape.dataExtent); auto outShape = TensorShape(x.shape.dataExtent);
outShape[-1] = k; outShape[-1] = k;
outShape.dataStride.clear(); outShape.dataStride.clear();
Tensor out = Tensor::empty(outShape, Tensor::INT32, x.device()); Tensor out = Tensor::empty(outShape, Tensor::INT32, x.device());
auto stream = getCurrentCUDAStream(); auto stream = getCurrentCUDAStream();
...@@ -294,10 +308,7 @@ Tensor topk(Tensor x, int k) { ...@@ -294,10 +308,7 @@ Tensor topk(Tensor x, int k) {
if constexpr (K > 0) { if constexpr (K > 0) {
dispatch(x.scalar_type(), [&]<typename scalar_t>() { dispatch(x.scalar_type(), [&]<typename scalar_t>() {
topk_kernel<scalar_t, K><<<ceilDiv(batch, 32), 32, 0, stream>>>( topk_kernel<scalar_t, K><<<ceilDiv(batch, 32), 32, 0, stream>>>(
x.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), out.data_ptr<int>(), N, x.stride(-2), batch);
out.data_ptr<int>(),
N, x.stride(-2), batch
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
}); });
} }
...@@ -312,4 +323,4 @@ template std::array<Tensor, 4> split_mod<4>(Tensor input); ...@@ -312,4 +323,4 @@ template std::array<Tensor, 4> split_mod<4>(Tensor input);
template std::array<Tensor, 5> split_mod<5>(Tensor input); template std::array<Tensor, 5> split_mod<5>(Tensor input);
template std::array<Tensor, 6> split_mod<6>(Tensor input); template std::array<Tensor, 6> split_mod<6>(Tensor input);
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -22,4 +22,4 @@ Tensor topk(Tensor x, int k); ...@@ -22,4 +22,4 @@ Tensor topk(Tensor x, int k);
template<size_t N> template<size_t N>
std::array<Tensor, N> split_mod(Tensor input); std::array<Tensor, N> split_mod(Tensor input);
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
namespace nunchaku::kernels { namespace nunchaku::kernels {
template<typename T> template<typename T>
__global__ void add_kernel(T *a, T *b, T *c, size_t length) { __global__ void add_kernel(T *a, T *b, T *c, size_t length) {
int i = threadIdx.x + blockIdx.x * blockDim.x; int i = threadIdx.x + blockIdx.x * blockDim.x;
...@@ -24,12 +23,21 @@ struct alignas(sizeof(T) * unroll) Tvec { ...@@ -24,12 +23,21 @@ struct alignas(sizeof(T) * unroll) Tvec {
}; };
template<typename T, int unroll, bool no_scale> template<typename T, int unroll, bool no_scale>
__global__ void mul_add_kernel(T *x, T *scale, T *bias, T scale_shift, size_t length, int mod_scale, int mod_bias, int64_t batch_stride_x, int64_t batch_stride_scale, int64_t batch_stride_bias) { __global__ void mul_add_kernel(T *x,
T *scale,
T *bias,
T scale_shift,
size_t length,
int mod_scale,
int mod_bias,
int64_t batch_stride_x,
int64_t batch_stride_scale,
int64_t batch_stride_bias) {
const int batch_id = blockIdx.y; const int batch_id = blockIdx.y;
int thread = threadIdx.x + blockIdx.x * blockDim.x; int thread = threadIdx.x + blockIdx.x * blockDim.x;
int i = thread * unroll; int i = thread * unroll;
int i_scale = i % mod_scale; int i_scale = i % mod_scale;
int i_bias = i % mod_bias; int i_bias = i % mod_bias;
if (i >= length) { if (i >= length) {
return; return;
...@@ -37,9 +45,9 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, T scale_shift, size_t le ...@@ -37,9 +45,9 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, T scale_shift, size_t le
using Tvec = nunchaku::kernels::Tvec<T, unroll>; using Tvec = nunchaku::kernels::Tvec<T, unroll>;
Tvec rx = *reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]); Tvec rx = *reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]);
Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale + batch_stride_scale * batch_id]); Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale + batch_stride_scale * batch_id]);
Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias + batch_stride_bias * batch_id]); Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias + batch_stride_bias * batch_id]);
#pragma unroll #pragma unroll
for (int k = 0; k < unroll; k++) { for (int k = 0; k < unroll; k++) {
...@@ -58,16 +66,16 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, T scale_shift, size_t le ...@@ -58,16 +66,16 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, T scale_shift, size_t le
*reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]) = rx; *reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]) = rx;
// #pragma unroll // #pragma unroll
// for (int k = 0; k < unroll; k++) { // for (int k = 0; k < unroll; k++) {
// // assert(i < length); // // assert(i < length);
// x[i] = x[i] * scale[i_scale] + bias[i_bias]; // x[i] = x[i] * scale[i_scale] + bias[i_bias];
// i++; // i++;
// i_scale++; // i_scale++;
// i_bias++; // i_bias++;
// // assert(i_scale < mod_scale); // // assert(i_scale < mod_scale);
// // assert(i_bias < mod_bias); // // assert(i_bias < mod_bias);
// } // }
} }
template<typename T, size_t N> template<typename T, size_t N>
...@@ -82,12 +90,13 @@ __global__ void split_mod_kernel(T *input, std::array<T *, N> output, size_t len ...@@ -82,12 +90,13 @@ __global__ void split_mod_kernel(T *input, std::array<T *, N> output, size_t len
} }
template<typename T> template<typename T>
__global__ void EmbeddingKernel(int32_t *__restrict__ input_id, T *__restrict__ output, T *__restrict__ lookup, int embed_dim) { __global__ void
EmbeddingKernel(int32_t *__restrict__ input_id, T *__restrict__ output, T *__restrict__ lookup, int embed_dim) {
int i = blockIdx.x; int i = blockIdx.x;
int32_t token_id = input_id[i]; int32_t token_id = input_id[i];
T *output_sample_ptr = output + i * embed_dim; T *output_sample_ptr = output + i * embed_dim;
T *target_embed = lookup + token_id * embed_dim; T *target_embed = lookup + token_id * embed_dim;
for (int j = threadIdx.x; j < embed_dim; j += blockDim.x) { for (int j = threadIdx.x; j < embed_dim; j += blockDim.x) {
output_sample_ptr[j] = target_embed[j]; output_sample_ptr[j] = target_embed[j];
...@@ -97,7 +106,7 @@ __global__ void EmbeddingKernel(int32_t *__restrict__ input_id, T *__restrict__ ...@@ -97,7 +106,7 @@ __global__ void EmbeddingKernel(int32_t *__restrict__ input_id, T *__restrict__
template<typename T> template<typename T>
__global__ void argmax_sample_kernel(T *input, int32_t *output, int hidden_dim) { __global__ void argmax_sample_kernel(T *input, int32_t *output, int hidden_dim) {
float maxValue = -1e20; float maxValue = -1e20;
int argmax = 0; int argmax = 0;
for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) { for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
float data = (float)input[blockIdx.x * hidden_dim + i]; float data = (float)input[blockIdx.x * hidden_dim + i];
if (data > maxValue) { if (data > maxValue) {
...@@ -105,7 +114,7 @@ __global__ void argmax_sample_kernel(T *input, int32_t *output, int hidden_dim) ...@@ -105,7 +114,7 @@ __global__ void argmax_sample_kernel(T *input, int32_t *output, int hidden_dim)
argmax = i; argmax = i;
} }
} }
// blockAllReduceMax seems to be broken when T=half // blockAllReduceMax seems to be broken when T=half
float maxValueBlock = vllm::blockAllReduceMax(maxValue); float maxValueBlock = vllm::blockAllReduceMax(maxValue);
if (maxValue == maxValueBlock) { if (maxValue == maxValueBlock) {
output[blockIdx.x] = argmax; output[blockIdx.x] = argmax;
...@@ -127,14 +136,14 @@ __global__ void splitqkv_kernel(T *qkv, T *q, T *k, T *v, int q_size, int kv_siz ...@@ -127,14 +136,14 @@ __global__ void splitqkv_kernel(T *qkv, T *q, T *k, T *v, int q_size, int kv_siz
} }
} }
template <typename T, int unroll> template<typename T, int unroll>
__global__ void quant_kernel_static(const T * input, int8_t * output, T scale, size_t length) { __global__ void quant_kernel_static(const T *input, int8_t *output, T scale, size_t length) {
int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll; int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
if (i >= length) { if (i >= length) {
return; return;
} }
using Tvec = nunchaku::kernels::Tvec<T, unroll>; using Tvec = nunchaku::kernels::Tvec<T, unroll>;
using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>; using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>;
Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]); Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
...@@ -149,14 +158,14 @@ __global__ void quant_kernel_static(const T * input, int8_t * output, T scale, s ...@@ -149,14 +158,14 @@ __global__ void quant_kernel_static(const T * input, int8_t * output, T scale, s
*reinterpret_cast<I8vec *>(&output[i]) = routput; *reinterpret_cast<I8vec *>(&output[i]) = routput;
} }
template <typename T, int unroll> template<typename T, int unroll>
__global__ void quant_kernel_static_fuse_gelu(const T * input, int8_t * output, T scale, size_t length) { __global__ void quant_kernel_static_fuse_gelu(const T *input, int8_t *output, T scale, size_t length) {
int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll; int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
if (i >= length) { if (i >= length) {
return; return;
} }
using Tvec = nunchaku::kernels::Tvec<T, unroll>; using Tvec = nunchaku::kernels::Tvec<T, unroll>;
using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>; using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>;
Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]); Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
...@@ -175,10 +184,10 @@ template<typename Tin, typename Tout, int unroll> ...@@ -175,10 +184,10 @@ template<typename Tin, typename Tout, int unroll>
__global__ void cast_kernel(const Tin *input, Tout *output, size_t length) { __global__ void cast_kernel(const Tin *input, Tout *output, size_t length) {
const int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll; const int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
using Tvec_in = nunchaku::kernels::Tvec<Tin, unroll>; using Tvec_in = nunchaku::kernels::Tvec<Tin, unroll>;
using Tvec_out = nunchaku::kernels::Tvec<Tout, unroll>; using Tvec_out = nunchaku::kernels::Tvec<Tout, unroll>;
Tvec_in rinput = *reinterpret_cast<const Tvec_in *>(&input[i]); Tvec_in rinput = *reinterpret_cast<const Tvec_in *>(&input[i]);
Tvec_out routput; Tvec_out routput;
#pragma unroll #pragma unroll
...@@ -196,16 +205,15 @@ __global__ void cast_kernel(const Tin *input, Tout *output, size_t length) { ...@@ -196,16 +205,15 @@ __global__ void cast_kernel(const Tin *input, Tout *output, size_t length) {
// input: [..., N] // input: [..., N]
// output: [..., K] of indices in reverse order // output: [..., K] of indices in reverse order
template<typename T, int K> template<typename T, int K>
__global__ __global__ void topk_kernel(const T *input, int *output, int N, int strideInput, int numRows) {
void topk_kernel(const T *input, int *output, int N, int strideInput, int numRows) { const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int offset = row * strideInput; const int offset = row * strideInput;
if (row >= numRows) { if (row >= numRows) {
return; return;
} }
T val[K]; T val[K];
int16_t idx[K]; int16_t idx[K];
#pragma unroll #pragma unroll
...@@ -224,7 +232,7 @@ void topk_kernel(const T *input, int *output, int N, int strideInput, int numRow ...@@ -224,7 +232,7 @@ void topk_kernel(const T *input, int *output, int N, int strideInput, int numRow
for (int i = K; i < N; i++) { for (int i = K; i < N; i++) {
T newval = input[offset + i]; T newval = input[offset + i];
T minval = val[0]; T minval = val[0];
int minpos = 0; int minpos = 0;
#pragma unroll #pragma unroll
for (int j = 1; j < K; j++) { for (int j = 1; j < K; j++) {
...@@ -259,4 +267,4 @@ void topk_kernel(const T *input, int *output, int N, int strideInput, int numRow ...@@ -259,4 +267,4 @@ void topk_kernel(const T *input, int *output, int N, int strideInput, int numRow
} }
} }
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
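topk_kernel above keeps the current K best values (and their indices) in registers and, for each remaining element, replaces the running minimum when a larger value appears. A scalar sketch of that selection loop (the ordering of the emitted indices in the real kernel is omitted here):

// Select the K largest values of `input[0..N)` and their positions.
static void topk_select(const float *input, int N, int K, float *val, int *idx) {
    for (int i = 0; i < K; ++i) { val[i] = input[i]; idx[i] = i; }
    for (int i = K; i < N; ++i) {
        int minpos = 0;
        for (int j = 1; j < K; ++j)
            if (val[j] < val[minpos]) minpos = j;
        if (input[i] > val[minpos]) { val[minpos] = input[i]; idx[minpos] = i; }
    }
}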
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -18,86 +19,80 @@ ...@@ -18,86 +19,80 @@
#pragma once #pragma once
#define FINAL_MASK 0xffffffff #define FINAL_MASK 0xffffffff
namespace vllm { namespace vllm {
template<typename T> template<typename T>
__inline__ __device__ T warpReduceSum(T val) { __inline__ __device__ T warpReduceSum(T val) {
#pragma unroll #pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(0xffffffff, val, mask, 32); val += __shfl_xor_sync(0xffffffff, val, mask, 32);
return val; return val;
} }
template <typename T, int NUM> template<typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val) __inline__ __device__ T warpReduceSumV2(T *val) {
{
#pragma unroll #pragma unroll
for (int i = 0; i < NUM; i++) for (int i = 0; i < NUM; i++) {
{
#pragma unroll #pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
} }
return (T) (0.0f); return (T)(0.0f);
} }
/* Calculate the sum of all elements in a block */ /* Calculate the sum of all elements in a block */
template<typename T> template<typename T>
__inline__ __device__ T blockReduceSum(T val) { __inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32]; static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5; int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
if (lane == 0) if (lane == 0)
shared[wid] = val; shared[wid] = val;
__syncthreads(); __syncthreads();
// Changed from blockDim.x >> 5 to blockDim.x / 32.f to handle the case // Changed from blockDim.x >> 5 to blockDim.x / 32.f to handle the case
// where blockDim.x is not divisible by 32 // where blockDim.x is not divisible by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
return val; return val;
} }
/* Calculate the sum of all elements in a block */ /* Calculate the sum of all elements in a block */
template<typename T> template<typename T>
__inline__ __device__ T blockAllReduceSum(T val) { __inline__ __device__ T blockAllReduceSum(T val) {
static __shared__ T shared[32]; static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5; int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
if (lane == 0) if (lane == 0)
shared[wid] = val; shared[wid] = val;
__syncthreads(); __syncthreads();
// Changed from blockDim.x >> 5 to blockDim.x / 32.f to handle the case // Changed from blockDim.x >> 5 to blockDim.x / 32.f to handle the case
// where blockDim.x is not divisible by 32 // where blockDim.x is not divisible by 32
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); val = (lane < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
return val; return val;
} }
template <typename T, int NUM> template<typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val) __inline__ __device__ T blockReduceSumV2(T *val) {
{
static __shared__ T shared[NUM][33]; static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f; int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5; int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val); warpReduceSumV2<T, NUM>(val);
if (lane == 0) if (lane == 0) {
{
#pragma unroll #pragma unroll
for (int i = 0; i < NUM; i++) for (int i = 0; i < NUM; i++) {
{
shared[i][wid] = val[i]; shared[i][wid] = val[i];
} }
} }
...@@ -106,17 +101,15 @@ __inline__ __device__ T blockReduceSumV2(T* val) ...@@ -106,17 +101,15 @@ __inline__ __device__ T blockReduceSumV2(T* val)
bool is_mask = threadIdx.x < (blockDim.x / 32.f); bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll #pragma unroll
for (int i = 0; i < NUM; i++) for (int i = 0; i < NUM; i++) {
{ val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
val[i] = is_mask ? shared[i][lane] : (T) (0.0f);
} }
warpReduceSumV2<T, NUM>(val); warpReduceSumV2<T, NUM>(val);
return (T) 0.0f; return (T)0.0f;
} }
template<typename T> template<typename T>
__inline__ __device__ T warpReduceMax(T val) __inline__ __device__ T warpReduceMax(T val) {
{
#pragma unroll #pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) for (int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
...@@ -124,13 +117,12 @@ __inline__ __device__ T warpReduceMax(T val) ...@@ -124,13 +117,12 @@ __inline__ __device__ T warpReduceMax(T val)
} }
/* Calculate the maximum of all elements in a block */ /* Calculate the maximum of all elements in a block */
template<typename T> template<typename T>
__inline__ __device__ T blockReduceMax(T val) __inline__ __device__ T blockReduceMax(T val) {
{
static __shared__ T shared[32]; static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // get max in each warp val = warpReduceMax(val); // get max in each warp
if (lane == 0) // record in-warp max by warp idx if (lane == 0) // record in-warp max by warp idx
shared[wid] = val; shared[wid] = val;
__syncthreads(); __syncthreads();
// Changed from blockDim.x >> 5 to blockDim.x / 32.f to handle the case // Changed from blockDim.x >> 5 to blockDim.x / 32.f to handle the case
...@@ -141,16 +133,15 @@ __inline__ __device__ T blockReduceMax(T val) ...@@ -141,16 +133,15 @@ __inline__ __device__ T blockReduceMax(T val)
} }
/* Calculate the maximum of all elements in a block */ /* Calculate the maximum of all elements in a block */
template <typename T> template<typename T>
__inline__ __device__ T blockAllReduceMax(T val) __inline__ __device__ T blockAllReduceMax(T val) {
{
static __shared__ T shared[32]; static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // get max in each warp val = warpReduceMax(val); // get max in each warp
if (lane == 0) // record in-warp max by warp idx if (lane == 0) // record in-warp max by warp idx
shared[wid] = val; shared[wid] = val;
__syncthreads(); __syncthreads();
...@@ -163,8 +154,4 @@ __inline__ __device__ T blockAllReduceMax(T val) ...@@ -163,8 +154,4 @@ __inline__ __device__ T blockAllReduceMax(T val)
return val; return val;
} }
} // namespace vllm } // namespace vllm
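A minimal usage sketch for blockReduceSum (not from this commit): every thread contributes a partial sum, and thread 0 of the block ends up holding the block-wide total.

// Sum row `blockIdx.x` of an [gridDim.x, n] matrix into out[blockIdx.x].
__global__ void block_sum_example(const float *in, float *out, int n) {
    float partial = 0.f;
    for (int i = threadIdx.x; i < n; i += blockDim.x)
        partial += in[blockIdx.x * n + i];
    const float total = vllm::blockReduceSum<float>(partial);
    if (threadIdx.x == 0)
        out[blockIdx.x] = total;
}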
// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp // Adapted from FasterTransformer,
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once #pragma once
#include <cassert> #include <cassert>
...@@ -14,8 +15,7 @@ ...@@ -14,8 +15,7 @@
#include <cuda_bf16.h> #include <cuda_bf16.h>
#endif #endif
__device__ __forceinline__ __device__ __forceinline__ static void trap_unsupported_arch() {
static void trap_unsupported_arch() {
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
printf("This kernel is not supported on your GPU\n"); printf("This kernel is not supported on your GPU\n");
} }
...@@ -25,64 +25,143 @@ static void trap_unsupported_arch() { ...@@ -25,64 +25,143 @@ static void trap_unsupported_arch() {
} }
#if defined(ENABLE_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(ENABLE_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__device__ __forceinline__ __device__ __forceinline__ static __nv_bfloat162
static __nv_bfloat162 __hfma2(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c) { __hfma2(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c) {
trap_unsupported_arch(); trap_unsupported_arch();
return __nv_bfloat162(0.0f, 0.0f); return __nv_bfloat162(0.0f, 0.0f);
} }
#endif #endif
template<typename T> struct num_elems; template<typename T>
template <> struct num_elems<float> { static constexpr int value = 1; }; struct num_elems;
template <> struct num_elems<float2> { static constexpr int value = 2; }; template<>
template <> struct num_elems<float4> { static constexpr int value = 4; }; struct num_elems<float> {
template <> struct num_elems<half> { static constexpr int value = 1; }; static constexpr int value = 1;
template <> struct num_elems<half2> { static constexpr int value = 2; }; };
template<>
struct num_elems<float2> {
static constexpr int value = 2;
};
template<>
struct num_elems<float4> {
static constexpr int value = 4;
};
template<>
struct num_elems<half> {
static constexpr int value = 1;
};
template<>
struct num_elems<half2> {
static constexpr int value = 2;
};
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; template<>
template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; struct num_elems<__nv_bfloat16> {
static constexpr int value = 1;
};
template<>
struct num_elems<__nv_bfloat162> {
static constexpr int value = 2;
};
#endif #endif
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; }; template<>
template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; }; struct num_elems<__nv_fp8_e4m3> {
static constexpr int value = 1;
};
template<>
struct num_elems<__nv_fp8x2_e4m3> {
static constexpr int value = 2;
};
#endif #endif
template<typename T, int num> struct packed_as; template<typename T, int num>
template<typename T> struct packed_as<T, 1> { using type = T; }; struct packed_as;
template<> struct packed_as<half, 2> { using type = half2; }; template<typename T>
template<> struct packed_as<float, 2> { using type = float2; }; struct packed_as<T, 1> {
template<> struct packed_as<int8_t, 2> { using type = int16_t; }; using type = T;
template<> struct packed_as<int32_t, 2> { using type = int2; }; };
template<> struct packed_as<half2, 1> { using type = half; }; template<>
template<> struct packed_as<float2, 1> { using type = float; }; struct packed_as<half, 2> {
using type = half2;
};
template<>
struct packed_as<float, 2> {
using type = float2;
};
template<>
struct packed_as<int8_t, 2> {
using type = int16_t;
};
template<>
struct packed_as<int32_t, 2> {
using type = int2;
};
template<>
struct packed_as<half2, 1> {
using type = half;
};
template<>
struct packed_as<float2, 1> {
using type = float;
};
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; template<>
template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; struct packed_as<__nv_bfloat16, 2> {
using type = __nv_bfloat162;
};
template<>
struct packed_as<__nv_bfloat162, 1> {
using type = __nv_bfloat16;
};
#endif #endif
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
template<> struct packed_as<__nv_fp8_e4m3, 2> { using type = __nv_fp8x2_e4m3; }; template<>
template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; }; struct packed_as<__nv_fp8_e4m3, 2> {
template<> struct packed_as<__nv_fp8_e5m2, 2> { using type = __nv_fp8x2_e5m2; }; using type = __nv_fp8x2_e4m3;
template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; }; };
template<>
struct packed_as<__nv_fp8x2_e4m3, 1> {
using type = __nv_fp8_e4m3;
};
template<>
struct packed_as<__nv_fp8_e5m2, 2> {
using type = __nv_fp8x2_e5m2;
};
template<>
struct packed_as<__nv_fp8x2_e5m2, 1> {
using type = __nv_fp8_e5m2;
};
#endif #endif
inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } inline __device__ float2 operator*(float2 a, float2 b) {
inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); } return make_float2(a.x * b.x, a.y * b.y);
inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); } }
inline __device__ float2 operator+(float2 a, float2 b) {
return make_float2(a.x + b.x, a.y + b.y);
}
inline __device__ float2 operator-(float2 a, float2 b) {
return make_float2(a.x - b.x, a.y - b.y);
}
inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } inline __device__ float2 operator*(float2 a, float b) {
inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); } return make_float2(a.x * b, a.y * b);
inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); } }
inline __device__ float2 operator+(float2 a, float b) {
return make_float2(a.x + b, a.y + b);
}
inline __device__ float2 operator-(float2 a, float b) {
return make_float2(a.x - b, a.y - b);
}
static inline __device__ int8_t float_to_int8_rn(float x) static inline __device__ int8_t float_to_int8_rn(float x) {
{
uint32_t dst; uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst); return reinterpret_cast<const int8_t &>(dst);
} }
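The inline PTX above rounds to nearest-even and saturates into the int8 range; a roughly equivalent plain-CUDA sketch (the codebase itself uses the PTX form):

// Round to nearest-even, clamp to [-128, 127], then narrow to int8.
static inline __device__ int8_t float_to_int8_rn_ref(float x) {
    const float r = rintf(x);                       // round to nearest-even
    const float c = fminf(fmaxf(r, -128.f), 127.f); // saturate to the int8 range
    return static_cast<int8_t>(c);
}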
template<typename T> template<typename T>
inline __device__ T ldg(const T* val) { inline __device__ T ldg(const T *val) {
return __ldg(val); return __ldg(val);
} }
...@@ -90,15 +169,13 @@ inline __device__ T ldg(const T* val) { ...@@ -90,15 +169,13 @@ inline __device__ T ldg(const T* val) {
#define bf1622float2 __bfloat1622float2 #define bf1622float2 __bfloat1622float2
#define float22bf162 __float22bfloat162_rn #define float22bf162 __float22bfloat162_rn
#define bf162bf162 __bfloat162bfloat162 #define bf162bf162 __bfloat162bfloat162
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val; float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f); f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f); f_val.y = max(min(__high2float(val), 127.f), -128.f);
union union {
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
...@@ -110,8 +187,7 @@ inline __device__ int16_t bf1622int16(__nv_bfloat162 val) ...@@ -110,8 +187,7 @@ inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
val = __hmin2(val, make_bfloat162(127., 127.)); val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.)); val = __hmax2(val, make_bfloat162(-128., -128.));
union union {
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
...@@ -125,7 +201,7 @@ inline __device__ int16_t bf1622int16(__nv_bfloat162 val) ...@@ -125,7 +201,7 @@ inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
#if ENABLE_BF16 #if ENABLE_BF16
template<> template<>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) { inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0]; return val[0];
#else #else
...@@ -134,7 +210,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) { ...@@ -134,7 +210,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) {
} }
template<> template<>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) { inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16 *val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0]; return val[0];
#else #else
...@@ -143,59 +219,49 @@ inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) { ...@@ -143,59 +219,49 @@ inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) {
} }
#endif // ENABLE_BF16 #endif // ENABLE_BF16
template <typename T_OUT, typename T_IN> template<typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) __device__ inline T_OUT cuda_cast(T_IN val) {
{
return val; return val;
} }
template <> template<>
__device__ inline float2 cuda_cast<float2, int2>(int2 val) __device__ inline float2 cuda_cast<float2, int2>(int2 val) {
{
return make_float2(val.x, val.y); return make_float2(val.x, val.y);
} }
template <> template<>
__device__ inline float2 cuda_cast<float2, float>(float val) __device__ inline float2 cuda_cast<float2, float>(float val) {
{
return make_float2(val, val); return make_float2(val, val);
} }
template <> template<>
__device__ inline float2 cuda_cast<float2, half2>(half2 val) __device__ inline float2 cuda_cast<float2, half2>(half2 val) {
{
return __half22float2(val); return __half22float2(val);
} }
template <> template<>
__device__ inline half2 cuda_cast<half2, float2>(float2 val) __device__ inline half2 cuda_cast<half2, float2>(float2 val) {
{
return __float22half2_rn(val); return __float22half2_rn(val);
} }
template <> template<>
__device__ inline half2 cuda_cast<half2, float>(float val) __device__ inline half2 cuda_cast<half2, float>(float val) {
{
return __float2half2_rn(val); return __float2half2_rn(val);
} }
template <> template<>
__device__ inline half2 cuda_cast<half2, half>(half val) __device__ inline half2 cuda_cast<half2, half>(half val) {
{
return __half2half2(val); return __half2half2(val);
} }
template <> template<>
__device__ inline int8_t cuda_cast<int8_t, half>(half val) __device__ inline int8_t cuda_cast<int8_t, half>(half val) {
{ union {
union
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
union union {
{
half fp16; half fp16;
int16_t int16_in; int16_t int16_in;
}; };
...@@ -205,11 +271,9 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val) ...@@ -205,11 +271,9 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val)
return int8[0]; return int8[0];
} }
template <> template<>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) {
{ union {
union
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
...@@ -219,11 +283,9 @@ __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) ...@@ -219,11 +283,9 @@ __device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
return int16; return int16;
} }
template <> template<>
__device__ inline int8_t cuda_cast<int8_t, float>(float val) __device__ inline int8_t cuda_cast<int8_t, float>(float val) {
{ union {
union
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
...@@ -232,11 +294,9 @@ __device__ inline int8_t cuda_cast<int8_t, float>(float val) ...@@ -232,11 +294,9 @@ __device__ inline int8_t cuda_cast<int8_t, float>(float val)
return int8[0]; return int8[0];
} }
template <> template<>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val) __device__ inline int16_t cuda_cast<int16_t, float2>(float2 val) {
{ union {
union
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
...@@ -246,11 +306,9 @@ __device__ inline int16_t cuda_cast<int16_t, float2>(float2 val) ...@@ -246,11 +306,9 @@ __device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
return int16; return int16;
} }
template <> template<>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val) __device__ inline half2 cuda_cast<half2, int16_t>(int16_t val) {
{ union {
union
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
...@@ -259,11 +317,9 @@ __device__ inline half2 cuda_cast<half2, int16_t>(int16_t val) ...@@ -259,11 +317,9 @@ __device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
return make_half2(int8[0], int8[1]); return make_half2(int8[0], int8[1]);
} }
template <> template<>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val) __device__ inline float2 cuda_cast<float2, int16_t>(int16_t val) {
{ union {
union
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
...@@ -273,83 +329,69 @@ __device__ inline float2 cuda_cast<float2, int16_t>(int16_t val) ...@@ -273,83 +329,69 @@ __device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template <> template<>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val) __device__ inline __nv_bfloat16 cuda_cast(int32_t val) {
{
return static_cast<float>(val); return static_cast<float>(val);
} }
template <> template<>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val) __device__ inline __nv_bfloat16 cuda_cast(int8_t val) {
{
return static_cast<float>(val); return static_cast<float>(val);
} }
template <> template<>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val) __device__ inline int8_t cuda_cast(__nv_bfloat16 val) {
{
return static_cast<float>(val); return static_cast<float>(val);
} }
template <> template<>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
{
return __bfloat162float(val); return __bfloat162float(val);
} }
template <> template<>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val) __device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val) {
{
return bf1622float2(val); return bf1622float2(val);
} }
template <> template<>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val) __device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val) {
{
return __float2half(__bfloat162float(val)); return __float2half(__bfloat162float(val));
} }
template <> template<>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val) __device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val) {
{
return bf1622int16(val); return bf1622int16(val);
} }
template <> template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) __device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
{
return __float2bfloat16(val); return __float2bfloat16(val);
} }
template <> template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) __device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) {
{
return __float2bfloat16(__half2float(val)); return __float2bfloat16(__half2float(val));
} }
template <> template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) {
{
return bf162bf162(val); return bf162bf162(val);
} }
template <> template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) {
{
return __float2bfloat162_rn(val); return __float2bfloat162_rn(val);
} }
template <> template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) {
{
return float22bf162(val); return float22bf162(val);
} }
template <> template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) {
{ union {
union
{
int8_t int8[2]; int8_t int8[2];
int16_t int16; int16_t int16;
}; };
...@@ -361,72 +403,57 @@ __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) ...@@ -361,72 +403,57 @@ __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
return res; return res;
} }
template <> template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) {
{
return float22bf162(__half22float2(val)); return float22bf162(__half22float2(val));
} }
#endif // ENABLE_BF16 #endif // ENABLE_BF16
template <typename f16_t> template<typename f16_t>
__device__ __forceinline__ __device__ __forceinline__ packed_as<f16_t, 2>::type f162f162(f16_t x);
packed_as<f16_t, 2>::type
f162f162(f16_t x);
template <> template<>
__device__ __forceinline__ __device__ __forceinline__ packed_as<half, 2>::type f162f162<half>(half x) {
packed_as<half, 2>::type return __half2half2(x);
f162f162<half>(half x)
{
    return __half2half2(x);
}
#ifdef ENABLE_BF16
template<>
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type f162f162<__nv_bfloat16>(__nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
}
#endif

template<typename To, typename Ti>
__device__ inline To cuda_sum(Ti val) {
    return cuda_cast<To>(val);
};

template<typename To>
__device__ inline To cuda_sum(float2 val) {
    return cuda_cast<To>(val.x + val.y);
};

// Unary maximum: compute the max of a vector type
template<typename To, typename Ti>
__device__ inline To cuda_max(Ti val) {
    return cuda_cast<To>(val);
};

template<>
__device__ inline float cuda_max(float2 val) {
    return fmaxf(val.x, val.y);
}

template<>
__device__ inline half cuda_max(half2 val) {
    return __hmax(val.x, val.y);
}

#ifdef ENABLE_BF16
template<>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    return __hmax(val.x, val.y);
#else
@@ -437,57 +464,49 @@ __device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
#endif

// Binary maximum: compute the max of two scalar types
template<typename T>
__device__ inline T cuda_max(T val1, T val2) {
    return (val1 > val2) ? val1 : val2;
}

template<typename T>
__device__ inline T cuda_abs(T val) {
    assert(false);
    return {};
}

template<>
__device__ inline float cuda_abs(float val) {
    return fabs(val);
}

template<>
__device__ inline float2 cuda_abs(float2 val) {
    return make_float2(fabs(val.x), fabs(val.y));
}

template<>
__device__ inline half cuda_abs(half val) {
    return __habs(val);
}

template<>
__device__ inline half2 cuda_abs(half2 val) {
    return __habs2(val);
}

#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) {
    return __habs(val);
}

template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) {
    return __habs2(val);
}
#endif
#endif // ENABLE_FP16
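The header keeps two flavors of cuda_max: a unary overload that reduces a vector type (float2/half2) to a scalar, and a binary overload that compares two scalars. A minimal host-side C++ sketch of the same overload pattern, using a stand-in float2 type (illustration only, not code from this commit):

#include <algorithm>
#include <cmath>
#include <cstdio>

struct float2_ { float x, y; };

template <typename T> float max_of(T v);                            // unary: reduce a vector type
template <> float max_of(float2_ v) { return std::max(v.x, v.y); }

template <typename T> T max_of(T a, T b) { return a > b ? a : b; }  // binary: two scalars

int main() {
    float2_ v{-3.0f, 2.0f};
    std::printf("unary max = %g, binary max = %g, abs = %g\n",
                max_of(v), max_of(1.0f, 2.0f), std::fabs(-3.0f));
    return 0;
}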
@@ -7,17 +7,15 @@
namespace nunchaku::kernels {

void attention_fp16(Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
                    Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
                    float scale) {
    int sizeBatch = q.shape[0];
    int numHeads = q.shape[1];
    int numTokensQ = q.shape[2];
    int headDim = q.shape[3];
    int numTokensKV = k.shape[2];

    assert(o.ndims() == 3);
@@ -55,20 +53,20 @@ void attention_fp16(
        assert(headDim == Attention::HEAD_DIM);

        auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
            dim3 grid(numTokensQ / Attention::BLOCK_M, numHeads, sizeBatch);

            using packed_q_t = typename Attention::packed_q_t;
            using packed_k_t = typename Attention::packed_k_t;
            using packed_v_t = typename Attention::packed_v_t;

            auto func = invoke_kernel<typename Attention::attention_fp16_kernel<Epilogue>,
                                      const packed_q_t *,
                                      const packed_k_t *,
                                      const packed_v_t *,
                                      float,
                                      int,
                                      int,
                                      typename Epilogue::Arguments,
                                      bool>;

            shmem = std::max(shmem, Attention::template attention_fp16_kernel<Epilogue>::SHMEM_SIZE);
@@ -76,26 +74,23 @@ void attention_fp16(
                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
            }

            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(q.data_ptr<packed_q_t>(),
                                                                                             k.data_ptr<packed_k_t>(),
                                                                                             v.data_ptr<packed_v_t>(),
                                                                                             scale,
                                                                                             numTokensQ,
                                                                                             numTokensKV,
                                                                                             args,
                                                                                             false);
            checkCUDA(cudaGetLastError());
        };

        launch.template operator()<typename GEMM::EpilogueDefault>(typename GEMM::EpilogueDefault::Arguments{
            .out = o.data_ptr<typename GEMM::half_t>(),
            .actualM = sizeBatch * numTokensQ,
            .actualN = numHeads * headDim,
        });
    });
}

}; // namespace nunchaku::kernels
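The grid here is one block per (128-token Q tile, head, batch) triple, with WARP_SIZE * NUM_WARPS threads per block. A small arithmetic sketch with example shapes (the shape numbers are assumed for illustration; BLOCK_M, WARP_SIZE and NUM_WARPS are the AttentionFP16Config values shown further down):

#include <cstdio>

int main() {
    const int sizeBatch = 2, numHeads = 24, numTokensQ = 4096;   // example tensor shapes (assumed)
    const int BLOCK_M = 128, WARP_SIZE = 32, NUM_WARPS = 8;      // from AttentionFP16Config below
    // dim3 grid(numTokensQ / BLOCK_M, numHeads, sizeBatch); block = WARP_SIZE * NUM_WARPS threads
    std::printf("grid = (%d, %d, %d), threads per block = %d\n",
                numTokensQ / BLOCK_M, numHeads, sizeBatch, WARP_SIZE * NUM_WARPS);
    // -> grid = (32, 24, 2), 256 threads; the integer division assumes numTokensQ is a multiple of BLOCK_M.
    return 0;
}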
@@ -4,7 +4,7 @@
namespace nunchaku::kernels {

// M: Q tokens
// N: V HEAD_DIM
// K: K tokens
// D: QK HEAD_DIM
@@ -12,21 +12,21 @@ template<bool bf16out>
struct AttentionFP16Config {
    static constexpr int HEAD_DIM = 128;
    static constexpr int BLOCK_M = 128;

    static constexpr int WARP_SIZE = 32;
    static constexpr int NUM_WARPS = 8;
    static constexpr int WARP_K = 32;

    static constexpr int INSN_M = 16;
    static constexpr int INSN_N = 16;
    static constexpr int INSN_K_QK = 16;
    static constexpr int INSN_K_PV = 16;

    using half_t = half;
    using half2_t = half2;

    using epilogue_half_t = typename std::conditional_t<bf16out, __nv_bfloat16, half>;
    using epilogue_half2_t = typename std::conditional_t<bf16out, __nv_bfloat162, half2>;
};
@@ -67,8 +67,8 @@ public:
#endif

    struct GEMMConfig {
        static constexpr int BLOCK_M = AttentionConfig::BLOCK_M;
        static constexpr int BLOCK_N = AttentionConfig::HEAD_DIM;

        static constexpr int WARP_SIZE = AttentionConfig::WARP_SIZE;
        static constexpr int NUM_WARPS = AttentionConfig::NUM_WARPS;
@@ -86,23 +86,24 @@ public:
    static constexpr int WARP_N = HEAD_DIM;
    static constexpr int WARP_D = HEAD_DIM;

    static constexpr int WARP_M_TILES = WARP_M / INSN_M;
    static constexpr int WARP_N_TILES = WARP_N / INSN_N;
    static constexpr int WARP_K_TILES_QK =
        WARP_K / INSN_N; // when multiplying Q*K, K is on dimension of N in MMA instruction
    static constexpr int WARP_K_TILES_PV = WARP_K / INSN_K_PV;
    static constexpr int WARP_D_TILES = WARP_D / INSN_K_QK;

    using packed_q_t = uint4;
    using packed_k_t = uint4;
    using packed_v_t = uint4;

    using q_warp = std::array<packed_q_t, WARP_M_TILES * WARP_D_TILES>;
    using k_warp = std::array<packed_k_t, WARP_K_TILES_QK * WARP_D_TILES>;
    using v_warp = std::array<packed_v_t, WARP_K_TILES_PV * WARP_N_TILES>;

    using packed_p_t = uint4;
    using p_warp = std::array<packed_v_t, WARP_M_TILES * WARP_K_TILES_PV>;

    using packed_fpsum_t = uint4;
    using packed_f32psum_t = typename GEMM::packed_f32psum_t;

    using qk_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_K_TILES_QK>;
@@ -112,16 +113,15 @@ public:
    using rowval_warp = std::array<float2, WARP_M_TILES>;

    struct BlockInfo {
        int bm;    // M: Q tokens, bm: block id of M
        int head;  // H: head
        int batch; // B: batch
        int numBlocksM;
        int numHeads;
        int numBatch;
    };

    __device__ __forceinline__ static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
        std::array<half2_t, 4> results;
        for (int i = 0; i < 4; i++) {
            results[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1]));
@@ -129,21 +129,19 @@ public:
        return kernels::bit_cast<packed_fpsum_t>(results);
    }

    __device__ __forceinline__ static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
        auto arr = kernels::bit_cast<std::array<half2_t, 4>>(input);
        packed_f32psum_t results;
        for (int i = 0; i < 4; i++) {
            float2 tmp = half22float2(arr[i]);
            results.data[i * 2] = tmp.x;
            results.data[i * 2 + 1] = tmp.y;
        }
        return results;
    }
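The two converters above just regroup the eight fp32 accumulator lanes into four 2-element half vectors and back: half2 j holds (data[2j], data[2j+1]), and the resulting std::array of four half2 is bit-cast into one uint4. A tiny index-mapping sketch in plain C++ (floats stand in for half; illustration only):

#include <array>
#include <cstdio>

int main() {
    std::array<float, 8> f32 = {0, 1, 2, 3, 4, 5, 6, 7};   // packed_f32psum_t::data
    std::array<std::array<float, 2>, 4> pairs;             // stand-in for std::array<half2_t, 4>
    for (int j = 0; j < 4; j++)
        pairs[j] = {f32[j * 2], f32[j * 2 + 1]};
    for (int j = 0; j < 4; j++)
        std::printf("half2 %d <- (%g, %g)\n", j, pairs[j][0], pairs[j][1]);
    return 0;
}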
    // q: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_D_TILES, WARP_SIZE] of packed_q_t
    __device__ __forceinline__ static void load_q(const packed_q_t *ptr, q_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
@@ -157,8 +155,7 @@ public:
    }

    // k: [batch, head, ktile, WARP_K_TILES_QK, WARP_D_TILES, WARP_SIZE] of packed_k_t
    __device__ __forceinline__ static void load_k(const packed_k_t *ptr, int ktile, k_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
@@ -172,8 +169,7 @@ public:
    }

    // v: [batch, head, ktile, WARP_K_TILES_PV, WARP_N_TILES, WARP_SIZE] of packed_v_t
    __device__ __forceinline__ static void load_v(const packed_v_t *ptr, int ktile, v_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
@@ -186,34 +182,31 @@ public:
    }

    __device__ __forceinline__ static packed_fpsum_t
    mma_f16xf16_f16(packed_fpsum_t a, packed_fpsum_t b, packed_fpsum_t psum) {
        uint2 out1 = mma_m16n8k16_f16f16f16f16(a, uint2(b.x, b.y), uint2(psum.x, psum.y));
        uint2 out2 = mma_m16n8k16_f16f16f16f16(a, uint2(b.z, b.w), uint2(psum.z, psum.w));
        return packed_fpsum_t{out1.x, out1.y, out2.x, out2.y};
    }

    // set nan values to -inf
    __device__ __forceinline__ static half2_t fix_nan(half2_t input) {
        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        /**
         * In accordance to the IEEE-754R standard,
         * if one of the input parameters to fminf(), fmin(), fmaxf(), or fmax() is NaN,
         * but not the other,
         * the result is the non-NaN parameter.
         */
        return __hmax2(input, half2_t(neginf, neginf));
    }

    __device__ __forceinline__ static float fix_nan(float input) {
        static constexpr float neginf = -std::numeric_limits<float>::infinity();
        return fmaxf(input, neginf);
    }

    __device__ __forceinline__ static packed_fpsum_t fix_nan(packed_fpsum_t input) {
        input.x = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.x)));
        input.y = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.y)));
        input.z = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.z)));
@@ -221,30 +214,28 @@ public:
        return input;
    }

    __device__ __forceinline__ static packed_f32psum_t fix_nan(packed_f32psum_t input) {
#pragma unroll
        for (int i = 0; i < 8; i++) {
            input.data[i] = fix_nan(input.data[i]);
        }
        return input;
    }
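The IEEE-754 rule the comment above relies on can be checked on the host: fmax with one NaN operand returns the other operand, so taking the max against -inf maps NaN to -inf and leaves ordinary values untouched. A minimal host-side check (added for illustration only):

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
    const float neginf = -std::numeric_limits<float>::infinity();
    float vals[] = {1.5f, std::numeric_limits<float>::quiet_NaN(), -2.0f};
    for (float v : vals)
        std::printf("fix_nan(%g) = %g\n", v, std::fmax(v, neginf));  // 1.5, -inf, -2
    return 0;
}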
    __device__ __forceinline__ static qk_warp compute_qk(q_warp Q, k_warp K) {
        qk_warp QK;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
#if 0
#pragma unroll
                for (int d = 0; d < WARP_D_TILES; d++) {
                    packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
                    psum = mma_f16xf16_f16(Q[m * WARP_D_TILES + d], K[k * WARP_D_TILES + d], psum);
                    auto f32psum = packed_fp16_to_fp32(psum);
#pragma unroll
                    for (int i = 0; i < 8; i++) {
                        QK[m * WARP_K_TILES_QK + k].data[i] += f32psum.data[i];
                    }
@@ -252,34 +243,32 @@ public:
#else
                packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
#pragma unroll
                for (int d = 0; d < WARP_D_TILES; d++) {
                    psum = mma_f16xf16_f16(Q[m * WARP_D_TILES + d], K[k * WARP_D_TILES + d], psum);
                }
                if constexpr (IS_SM80) {
                    psum = fix_nan(psum);
                    QK[m * WARP_K_TILES_QK + k] = packed_fp16_to_fp32(psum);
                } else {
                    QK[m * WARP_K_TILES_QK + k] = fix_nan(packed_fp16_to_fp32(psum));
                }
#endif
            }
        }
        return QK;
    }

    __device__ __forceinline__ static rowval_warp compute_rowmax(qk_warp QK, rowval_warp rowmax, float scale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 maxv;
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                float x = fmaxf(fmaxf(val.data[0], val.data[1]), fmaxf(val.data[4], val.data[5]));
                float y = fmaxf(fmaxf(val.data[2], val.data[3]), fmaxf(val.data[6], val.data[7]));
                if (k == 0) {
                    maxv = make_float2(x, y);
                } else {
@@ -287,7 +276,7 @@ public:
                    maxv.y = fmaxf(maxv.y, y);
                }
            }
#pragma unroll
            for (int mask = 1; mask <= 2; mask *= 2) {
                maxv.x = fmaxf(maxv.x, __shfl_xor_sync(~0, maxv.x, mask));
                maxv.y = fmaxf(maxv.y, __shfl_xor_sync(~0, maxv.y, mask));
@@ -298,40 +287,38 @@ public:
        return rowmax;
    }
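The shuffle loop above is an XOR-butterfly reduction: with masks 1 and 2, every group of four consecutive lanes ends up holding the group's maximum, matching how an m16n8k16 accumulator spreads one output row across four consecutive threads. A host-side simulation of the four-lane butterfly (illustration only):

#include <algorithm>
#include <array>
#include <cstdio>

int main() {
    std::array<float, 4> lane = {3.0f, 7.0f, 1.0f, 5.0f};   // values held by lanes 0..3
    for (int mask = 1; mask <= 2; mask *= 2) {
        std::array<float, 4> next = lane;
        for (int i = 0; i < 4; i++)
            next[i] = std::max(lane[i], lane[i ^ mask]);     // mimics __shfl_xor_sync(~0, ..., mask)
        lane = next;
    }
    std::printf("%g %g %g %g\n", lane[0], lane[1], lane[2], lane[3]);  // all lanes hold 7
    return 0;
}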
    __device__ __forceinline__ static qk_warp softmax(qk_warp QK, rowval_warp rowmax_scaled, float scale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 shift = rowmax_scaled[m];
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                val.data[0] = cuda_exp2(fmaf(val.data[0], scale, -shift.x));
                val.data[1] = cuda_exp2(fmaf(val.data[1], scale, -shift.x));
                val.data[4] = cuda_exp2(fmaf(val.data[4], scale, -shift.x));
                val.data[5] = cuda_exp2(fmaf(val.data[5], scale, -shift.x));
                val.data[2] = cuda_exp2(fmaf(val.data[2], scale, -shift.y));
                val.data[3] = cuda_exp2(fmaf(val.data[3], scale, -shift.y));
                val.data[6] = cuda_exp2(fmaf(val.data[6], scale, -shift.y));
                val.data[7] = cuda_exp2(fmaf(val.data[7], scale, -shift.y));
            }
        }
        return QK;
    }

    __device__ __forceinline__ static rowval_warp compute_rowsum(qk_warp QK) {
        rowval_warp rowsum;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 sumv = make_float2(0.0f, 0.0f);
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_QK; k++) {
                packed_f32psum_t &val = QK[m * WARP_K_TILES_QK + k];
                sumv.x += val.data[0] + val.data[1] + val.data[4] + val.data[5];
                sumv.y += val.data[2] + val.data[3] + val.data[6] + val.data[7];
            }
#pragma unroll
            for (int mask = 1; mask <= 2; mask *= 2) {
                sumv.x += __shfl_xor_sync(~0, sumv.x, mask);
                sumv.y += __shfl_xor_sync(~0, sumv.y, mask);
@@ -341,10 +328,9 @@ public:
        return rowsum;
    }

    __device__ __forceinline__ static rowval_warp compute_rescale(rowval_warp rowmax0, rowval_warp rowmax1) {
        rowval_warp rescale;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            rescale[m].x = cuda_exp2(rowmax0[m].x - rowmax1[m].x);
            rescale[m].y = cuda_exp2(rowmax0[m].y - rowmax1[m].y);
@@ -352,36 +338,34 @@ public:
        return rescale;
    }

    __device__ __forceinline__ static o_warp compute_pv(p_warp P, v_warp V, o_warp O, rowval_warp rescale) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int n = 0; n < WARP_N_TILES; n++) {
                packed_fpsum_t psum = make_uint4(0, 0, 0, 0);
#pragma unroll
                for (int k = 0; k < WARP_K_TILES_PV; k++) {
                    psum = mma_f16xf16_f16(P[m * WARP_K_TILES_PV + k], V[n * WARP_K_TILES_PV + k], psum);
                }
                packed_f32psum_t pv = packed_fp16_to_fp32(psum);
                packed_f32psum_t &oval = O[m * WARP_N_TILES + n];
                oval.data[0] = oval.data[0] * rescale[m].x + pv.data[0];
                oval.data[1] = oval.data[1] * rescale[m].x + pv.data[1];
                oval.data[4] = oval.data[4] * rescale[m].x + pv.data[4];
                oval.data[5] = oval.data[5] * rescale[m].x + pv.data[5];
                oval.data[2] = oval.data[2] * rescale[m].y + pv.data[2];
                oval.data[3] = oval.data[3] * rescale[m].y + pv.data[3];
                oval.data[6] = oval.data[6] * rescale[m].y + pv.data[6];
                oval.data[7] = oval.data[7] * rescale[m].y + pv.data[7];
            }
        }
        return O;
    }

    __device__ __forceinline__ static rowval_warp compute_l(rowval_warp L, rowval_warp rescale, rowval_warp rowsum) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            L[m].x = fmaf(L[m].x, rescale[m].x, rowsum[m].x);
            L[m].y = fmaf(L[m].y, rescale[m].y, rowsum[m].y);
@@ -389,13 +373,12 @@ public:
        return L;
    }

    __device__ __forceinline__ static p_warp qk_to_p(qk_warp QK) {
        static_assert(WARP_K_TILES_QK == WARP_K_TILES_PV);
        p_warp P;
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
            for (int k = 0; k < WARP_K_TILES_PV; k++) {
                P[m * WARP_K_TILES_PV + k] = packed_fp32_to_fp16(QK[m * WARP_K_TILES_QK + k]);
            }
@@ -416,43 +399,41 @@ public:
    // O = compute_pv(P, V, O, rescale);
    // }

    __device__ __forceinline__ static std::tuple<p_warp, rowval_warp>
    compute(q_warp Q, k_warp K, rowval_warp &M, rowval_warp &L, float scale) {
        qk_warp qk = compute_qk(Q, K);
        rowval_warp M1 = compute_rowmax(qk, M, scale);
        qk = softmax(qk, M1, scale);
        rowval_warp rowsum = compute_rowsum(qk);
        p_warp P = qk_to_p(qk);
        rowval_warp rescale = compute_rescale(M, M1);
        M = M1;
        L = compute_l(L, rescale, rowsum);
        return {P, rescale};
    }
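compute(), together with compute_pv() and compute_o() below, implements the streaming ("online") softmax update per query row: keep a running max M, running denominator L and running output O, rescale them by exp2(M_old - M_new) whenever a new key chunk raises the max, and divide by L at the end. A scalar C++ sketch of the same recurrence (illustration only, not from this commit; it assumes the caller folds the log2(e) factor into `scale`, since the kernel works in base 2 via cuda_exp2):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    std::vector<float> scores = {0.3f, 2.0f, -1.0f, 1.5f};  // q·k for one query row (example values)
    std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f};   // matching (1-d) v entries (example)
    const float scale = 1.0f;                               // would carry 1/sqrt(d) and log2(e)

    float M = -std::numeric_limits<float>::max(), L = 0.0f, O = 0.0f;
    for (size_t k = 0; k < scores.size(); k++) {            // one "chunk" per key, for clarity
        float M1 = std::max(M, scores[k] * scale);          // compute_rowmax
        float p = std::exp2(scores[k] * scale - M1);        // softmax (shifted, base 2)
        float rescale = std::exp2(M - M1);                  // compute_rescale
        L = L * rescale + p;                                // compute_l
        O = O * rescale + p * values[k];                    // compute_pv
        M = M1;
    }
    std::printf("attention output = %f\n", O / L);          // compute_o
    return 0;
}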
    __device__ __forceinline__ static o_warp compute_o(o_warp O, rowval_warp L) {
#pragma unroll
        for (int m = 0; m < WARP_M_TILES; m++) {
            float2 inv;
            inv.x = cuda_frcp(L[m].x);
            inv.y = cuda_frcp(L[m].y);
#pragma unroll
            for (int n = 0; n < WARP_N_TILES; n++) {
                packed_f32psum_t &oval = O[m * WARP_N_TILES + n];
                oval.data[0] = oval.data[0] * inv.x;
                oval.data[1] = oval.data[1] * inv.x;
                oval.data[4] = oval.data[4] * inv.x;
                oval.data[5] = oval.data[5] * inv.x;
                oval.data[2] = oval.data[2] * inv.y;
                oval.data[3] = oval.data[3] * inv.y;
                oval.data[6] = oval.data[6] * inv.y;
                oval.data[7] = oval.data[7] * inv.y;
            }
        }
        return O;
    }

#if 0
    template<typename Epilogue>
    __device__ __forceinline__
@@ -462,10 +443,10 @@ public:
        const packed_k_t *ptr_k,
        const packed_v_t *ptr_v,
        float scale,
        int ntokens_q,
        int ntokens_kv,
        Epilogue::Arguments epilogueArgs,
        bool alwaysfalse)
    {
        constexpr int NUM_STAGES = 2;
@@ -485,9 +466,9 @@ public:
            load_v(ptr_v, k, V[k], true);
        }

#pragma unroll
        for (auto &pack : O) {
#pragma unroll
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
@@ -498,7 +479,7 @@ public:
        M.fill(make_float2(neginf, neginf));

        __shared__ q_warp Q_shmem[NUM_WARPS];
#pragma unroll
        for (int i = 0; i < Q.size(); i++) {
            store<true>(&Q_shmem[warpId][i], Q[i]);
        }
@@ -507,9 +488,9 @@ public:
        // TODO: mask tokens in last block
        for (int k1 = 0; k1 < ntokens_kv / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
#pragma unroll
                for (int i = 0; i < Q.size(); i++) {
                    Q[i] = load<true>(&Q_shmem[warpId][i]);
                }
@@ -519,7 +500,7 @@ public:
                bool pred = nextk < ntokens_kv / WARP_K;
                load_k(ptr_k, nextk, K[idx], pred);
                load_v(ptr_v, nextk, V[idx], pred);

                // __syncthreads();
                // if (alwaysfalse) {
                //     dummy = clock();
@@ -549,42 +530,40 @@ public:
    }
#else
    template<typename Epilogue>
    __device__ __forceinline__ static void attention_fp16_block(const BlockInfo binfo,
                                                                const packed_q_t *ptr_q,
                                                                const packed_k_t *ptr_k,
                                                                const packed_v_t *ptr_v,
                                                                float scale,
                                                                int ntokens_q,
                                                                int ntokens_kv,
                                                                Epilogue::Arguments epilogueArgs,
                                                                bool alwaysfalse) {
        // constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        q_warp Q;      // 32
        k_warp K;      // 64
        v_warp V;      // 64
        o_warp O;      // 64
        rowval_warp L; // 2
        rowval_warp M; // 2

        load_q(ptr_q, Q, true);
        load_k(ptr_k, 0, K, true);

#pragma unroll
        for (auto &pack : O) {
#pragma unroll
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        static constexpr float neginf =
            -std::numeric_limits<float>::max(); // not real inf, to prevent nan during computation
        L.fill(make_float2(0.0f, 0.0f));
        M.fill(make_float2(neginf, neginf));
@@ -593,7 +572,7 @@ public:
        using q_shmem_t = packed_q_t[NUM_WARPS][SHMEM_TILES][WARP_SIZE];
        __shared__ q_shmem_t Q_shmem;
#pragma unroll
        for (int i = 0; i < SHMEM_TILES; i++) {
            store<true>(&Q_shmem[warpId][i][laneId], Q[Q.size() - 1 - i]);
        }
@@ -602,12 +581,12 @@ public:
        int dummy = 0;

        // TODO: mask tokens in last block
        for (int k1 = 0; k1 < ntokens_kv / WARP_K; k1++) {

            if (alwaysfalse) {
                ptr_v += K[0].x;
            }

#pragma unroll
            for (int i = 0; i < SHMEM_TILES; i++) {
                Q[Q.size() - 1 - i] = load<true>(&Q_shmem[warpId][i][laneId]);
            }
@@ -628,8 +607,6 @@ public:
                dummy = clock();
            }

            auto [P, rescale] = compute(Q, K, M, L, scale);

            if (alwaysfalse) {
@@ -644,9 +621,7 @@ public:
            // dummy = clock();
            // }

            load_k(ptr_k, k1 + 1, K, k1 + 1 < ntokens_kv / WARP_K);

            // if (alwaysfalse) {
            //     dummy = clock();
@@ -665,38 +640,41 @@ public:
        auto f16psum = GEMM::packed_fp32_to_fp16(O);

        Epilogue()(
            typename GEMM::BlockInfo{
                .bm = binfo.batch * binfo.numBlocksM + binfo.bm,
                .bn = binfo.head,
                .numBlocksM = binfo.numBatch * binfo.numBlocksM,
                .numBlocksN = binfo.numHeads,
            },
            f16psum,
            binfo.numBatch * binfo.numBlocksM * BLOCK_M,
            binfo.numHeads * HEAD_DIM,
            0,
            epilogueArgs);
    }
#endif

    template<typename Epilogue>
    struct attention_fp16_kernel {
        static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
        static constexpr int SHMEM_SIZE = 0; // sizeof(q_shmem_t);
        __device__ void operator()(const packed_q_t *ptr_q,
                                   const packed_k_t *ptr_k,
                                   const packed_v_t *ptr_v,
                                   float scale,
                                   int ntokens_q,
                                   int ntokens_kv,
                                   Epilogue::Arguments epilogueArgs,
                                   bool alwaysfalse) {
            BlockInfo binfo = {
                .bm = (int)blockIdx.x,
                .head = (int)blockIdx.y,
                .batch = (int)blockIdx.z,
                .numBlocksM = (int)gridDim.x,
                .numHeads = (int)gridDim.y,
                .numBatch = (int)gridDim.z,
            };

            // extern __shared__ uint8_t shmem[];
@@ -706,21 +684,20 @@ public:
            attention_fp16_block<Epilogue>(
                binfo,
                ptr_q + ((binfo.batch * binfo.numHeads + binfo.head) * binfo.numBlocksM + binfo.bm) * NUM_WARPS *
                            WARP_M_TILES * WARP_D_TILES * WARP_SIZE,
                ptr_k +
                    (binfo.batch * binfo.numHeads + binfo.head) * ktiles * WARP_K_TILES_QK * WARP_D_TILES * WARP_SIZE,
                ptr_v +
                    (binfo.batch * binfo.numHeads + binfo.head) * ktiles * WARP_K_TILES_PV * WARP_N_TILES * WARP_SIZE,
                scale,
                ntokens_q,
                ntokens_kv,
                // *Q_shmem,
                epilogueArgs,
                alwaysfalse);
        }
    };
};

}; // namespace nunchaku::kernels
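The q pointer handed to attention_fp16_block above skips whole [NUM_WARPS, WARP_M_TILES, WARP_D_TILES, WARP_SIZE] tiles, indexed by (batch, head, bm) exactly as the documented [batch, head, bm, ...] layout implies. A small arithmetic sketch with assumed tile counts (the concrete numbers below are hypothetical, chosen only to show the flattening):

#include <cstdio>

int main() {
    const int numHeads = 24, numBlocksM = 32;                                       // example grid (assumed)
    const int NUM_WARPS = 8, WARP_M_TILES = 1, WARP_D_TILES = 8, WARP_SIZE = 32;    // assumed tile counts
    const int batch = 1, head = 3, bm = 5;                                          // one block's coordinates
    long offset = ((long)(batch * numHeads + head) * numBlocksM + bm) *
                  NUM_WARPS * WARP_M_TILES * WARP_D_TILES * WARP_SIZE;              // in packed_q_t units
    std::printf("q block offset = %ld packed_q_t elements\n", offset);
    return 0;
}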
@@ -19,22 +19,23 @@ public:
    IMPORT_GEMM_BASE(Config);

public:
    struct EpilogueGelu {
        struct Arguments {
            size_t unused;
        };

        // static constexpr float SHIFT_VALUE = 0.171875f;

        __device__ __forceinline__ void
        operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
#pragma unroll
            for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
                for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
                    for (int k = 0; k < 4; k++) {
                        half2_t &data = fpsum[i * WARP_N_TILES + j].data[k];
                        data = gelu_half2(data);
                        // data = __hadd2(data, half2_t(SHIFT_VALUE, SHIFT_VALUE));
                    }
                }
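EpilogueGelu applies gelu_half2 elementwise to every half2 lane of the accumulator before it is written out. The exact approximation used by gelu_half2 is defined elsewhere in the repo and not shown in this diff; as a reference, the widely used tanh approximation of GELU looks like this (scalar host-side sketch, illustration only):

#include <cmath>
#include <cstdio>

float gelu_tanh(float x) {
    const float c = 0.7978845608f;  // sqrt(2 / pi)
    return 0.5f * x * (1.0f + std::tanh(c * (x + 0.044715f * x * x * x)));
}

int main() {
    for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
        std::printf("gelu(%g) = %f\n", x, gelu_tanh(x));
    return 0;
}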
@@ -48,33 +49,41 @@ public:
            half_t *out;
            int actualM, actualN;

            half_t *pool_out; // [M / PoolSize, N]

            const float *rotary_emb;        // [M, HEAD_DIM / 2, ROTARY_EMB_NUM_ELEMENTS]
            const half_t *rmsnorm_weight_q; // [HEAD_DIM]
            const half_t *rmsnorm_weight_k; // [HEAD_DIM]
            float epsilon;
        };

        static constexpr int HEAD_DIM = 128;
        static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;

        static constexpr int PoolSize = 128;
        static constexpr int NUM_WARPS_PER_POOL = PoolSize / WARP_M;
        static constexpr int NUM_POOLS_PER_BLOCK = BLOCK_M / PoolSize;

        static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2; // 1 for theta, 2 for {sin, cos} pair

        __device__ __forceinline__ static void apply(fpsum_warp fpsum,
                                                     half_t *out,
                                                     int M,
                                                     int N,
                                                     int K,
                                                     half_t *pool_out,
                                                     const float *rotary_emb,
                                                     const half_t *rmsnorm_weight,
                                                     float epsilon,
                                                     int maxRows) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            __shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];

            constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
            using pack_t = unpack_fpsum::pack_t;
            using pack_rope_t = std::array<float, PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS>;

            constexpr int LANES_PER_HEAD = HEAD_DIM / PACK_SIZE;

            pack_t reduce_tmp;
@@ -91,98 +100,107 @@ public:
                }
            }

            const float *rotary_emb_base_addr = &rotary_emb[(warpId * WARP_M) * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS +
                                                            laneId * PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS];

            CHECK_NAN(fpsum, "fpsum");

            unpack_fpsum()(fpsum,
                           out + warpId * WARP_M * N,
                           N,
                           maxRows - warpId * WARP_M,
                           INT_MAX,
                           shmem[warpId],
                           [&](int rowId, pack_t &pack) ALWAYSINLINE {
                // load rope
                pack_rope_t rope;
                if (laneId < LANES_PER_HEAD) {
                    // freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) *
                    // HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
                    rope = load(reinterpret_cast<const pack_rope_t *>(
                        &rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
                }
                if constexpr (LANES_PER_HEAD < WARP_SIZE) {
                    for (int i = 0; i < rope.size(); i++) {
                        rope[i] = __shfl_sync(~0, rope[i], laneId % LANES_PER_HEAD);
                    }
                }

                // rmsnorm
                float sqrsum = 0.0f;
                for (int i = 0; i < PACK_SIZE; i++) {
                    sqrsum += float(pack[i]) * float(pack[i]);
                    CHECK_NAN(sqrsum, "sqrsum");
                }
#pragma unroll
                for (int mask = LANES_PER_HEAD / 2; mask > 0; mask /= 2) {
                    sqrsum += __shfl_xor_sync(~0, sqrsum, mask);
                }
                sqrsum /= HEAD_DIM;
                float coef = cuda_frsqrt(sqrsum + epsilon);
                CHECK_NAN(coef, "coef");

                for (int i = 0; i < PACK_SIZE; i++) {
                    pack[i] *= coef * float(rms[i]);

                    CHECK_NAN(rms[i], "rms.wgt");
                    CHECK_NAN(pack[i], "rms.out");
                }

#if 1
                // rope
                for (int i = 0; i < PACK_SIZE; i += 2) {
                    float2 pack2 = half22float2(half2_t(pack[i], pack[i + 1]));

                    CHECK_NAN(freq[i].x, "rope.freq");
                    CHECK_NAN(freq[i].y, "rope.freq");
                    CHECK_NAN(freq[i + 1].x, "rope.freq");
                    CHECK_NAN(freq[i + 1].y, "rope.freq");

                    // half2_t tmp = __hmul2(freq[i], pack2);
                    // tmp = __hfma2(freq[i+1], pack2, tmp);
                    // pack[i] = tmp.x;
                    // pack[i+1] = tmp.y;

                    // printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
                    //     blockIdx.x, blockIdx.y, warpId, rowId,
                    //     blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
                    //     (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
                    // );
                    // __trap();

                    // half2_t tmp = __hmul2(half2_t(pack2.x, pack2.x), freq[i]);
                    // tmp = __hfma2(half2_t(pack2.y, pack2.y), freq[i+1], tmp);
                    // pack[i] = tmp.x;
                    // pack[i+1] = tmp.y;

                    float sin, cos;

                    if constexpr (ROTARY_EMB_NUM_ELEMENTS == 1) {
                        sin = cuda_sin(rope[i / 2]);
                        cos = cuda_cos(rope[i / 2]);
                    }
                    if constexpr (ROTARY_EMB_NUM_ELEMENTS == 2) {
                        sin = rope[i];
                        cos = rope[i + 1];
                    }

                    // pack[i] = pack2.x * freq[i].x + pack2.y * freq[i].y;
                    // pack[i+1] = pack2.x * freq[i+1].x + pack2.y * freq[i+1].y;

                    pack[i] = half_t(pack2.x * cos - pack2.y * sin);
                    pack[i + 1] = half_t(pack2.x * sin + pack2.y * cos);

                    CHECK_NAN(pack[i], "rope.out");
                    CHECK_NAN(pack[i + 1], "rope.out");
                }
#endif

                // mean pool
                for (int i = 0; i < PACK_SIZE; i++) {
                    reduce_tmp[i] += pack[i];
                }
            });
            if (!pool_out) {
                return;
@@ -193,7 +211,7 @@ public:
            if (warpId < NUM_POOLS_PER_BLOCK) {
                const int row = warpId * NUM_WARPS_PER_POOL;
                reduce_tmp = load<true>(&pool[row]);
#pragma unroll
                for (int i = 1; i < NUM_WARPS_PER_POOL; i++) {
                    pack_t pack = load<true>(&pool[row + i]);
@@ -210,8 +228,8 @@ public:
            __syncthreads();
        }

        __device__ __forceinline__ void
        operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;
@@ -223,34 +241,39 @@ public:
            assert(args.actualN == N);

            if (is_q || is_k) {
                apply(fpsum,
                      args.out + bm * BLOCK_M * args.actualN + bn * BLOCK_N,
                      M,
                      N,
                      K,
                      args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
                      args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
                      is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
                      args.epsilon,
                      args.actualM - bm * BLOCK_M);
            } else {
                EpilogueDefault()(binfo,
                                  fpsum,
                                  M,
                                  N,
                                  K,
                                  typename EpilogueDefault::Arguments{
                                      .out = args.out,
                                      .actualM = args.actualM,
                                      .actualN = args.actualN,
                                  });
            }
        }
    };
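The per-row lambda in this epilogue first RMS-normalizes each head vector with the learned weight and then rotates consecutive (even, odd) element pairs by the precomputed {sin, cos} rotary embedding (the ROTARY_EMB_NUM_ELEMENTS == 2 path). A scalar C++ sketch of the same two steps (illustration only; the toy vector stands in for one HEAD_DIM-long head):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> x = {0.5f, -1.0f, 2.0f, 0.25f};        // one tiny head vector (example values)
    std::vector<float> w = {1.0f, 1.0f, 0.5f, 2.0f};          // rmsnorm weight (example)
    std::vector<float> rope = {0.0f, 1.0f, 0.1736f, 0.9848f}; // {sin, cos} per pair (example angles)
    const float eps = 1e-6f;

    // rmsnorm: coef = 1 / sqrt(mean(x^2) + eps), x[i] *= coef * w[i]
    float sqrsum = 0.0f;
    for (float v : x) sqrsum += v * v;
    float coef = 1.0f / std::sqrt(sqrsum / x.size() + eps);
    for (size_t i = 0; i < x.size(); i++) x[i] *= coef * w[i];

    // rope: rotate each (x[i], x[i+1]) pair by its angle
    for (size_t i = 0; i < x.size(); i += 2) {
        float s = rope[i], c = rope[i + 1];
        float xr = x[i] * c - x[i + 1] * s;
        float xi = x[i] * s + x[i + 1] * c;
        x[i] = xr;
        x[i + 1] = xi;
    }
    for (float v : x) std::printf("%f ", v);
    std::printf("\n");
    return 0;
}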
    struct EpilogueRMSNormRope {
        static constexpr int HEAD_DIM = 128;
        static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
        static constexpr int WARP_N_TILES_PER_HEAD = WARP_N_TILES / NUM_HEADS_PER_WARP;

        static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2;
        using packed_rotemb_t = float4;
        static constexpr int WARP_N_ROTEMB_TILES = WARP_N_TILES / NUM_HEADS_PER_WARP * 2;
        using rotemb_warp = std::array<packed_rotemb_t, WARP_M_TILES * WARP_N_ROTEMB_TILES>; // 128 regs
@@ -263,17 +286,17 @@ public:
            float epsilon;
        };

        __device__ __forceinline__ static rotemb_warp load_rotemb(const packed_rotemb_t *ptr_rotemb) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            rotemb_warp rotemb;

            const packed_rotemb_t *ptrlane =
                &ptr_rotemb[warpId * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE + laneId];

            unrolled_loop<WARP_M_TILES>([&]<int i>() {
                unrolled_loop<WARP_N_ROTEMB_TILES>([&]<int j>() {
                    constexpr int offset = (i * WARP_N_ROTEMB_TILES + j) * WARP_SIZE;
                    rotemb[i * WARP_N_ROTEMB_TILES + j] = load(&ptrlane[offset]);
                });
            });
@@ -281,28 +304,26 @@ public:
            return rotemb;
        }

        __device__ __forceinline__ static void load_rmsnorm(const half_t *ptr_rmsnorm_weight, half_t *shmem) {
            const int laneId = threadIdx.x % WARP_SIZE;

            static constexpr int PACK_SIZE = HEAD_DIM / WARP_SIZE;
            using packed_t = std::array<half_t, PACK_SIZE>;

            packed_t pack = load(reinterpret_cast<const packed_t *>(ptr_rmsnorm_weight + laneId * PACK_SIZE));
            store<true>(reinterpret_cast<packed_t *>(shmem + laneId * PACK_SIZE), pack);
        }

        __device__ __forceinline__ static packed_fpsum_t load_rmsnorm_from_shmem(half_t *shmem, int n) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int col = n * INSN_N + laneId / 16 * 8; // lane 0-15: n*16+0, lane 16-31: n*16+8
            uint4 tmp;
            ldmatrix(shmem + col, tmp);
            return kernels::bit_cast<packed_fpsum_t>(tmp);
        }

        __device__ __forceinline__ static void
        apply(fpsum_warp &fpsum, const packed_rotemb_t *ptr_rotemb, const half_t *ptr_rmsnorm_weight, float epsilon) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;
@@ -319,21 +340,21 @@ public:
                return fval.x * fval.x + fval.y * fval.y;
            };

#pragma unroll
            for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
                const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
                for (int m = 0; m < WARP_M_TILES; m++) {
                    float sqrsum[2] = {0.0f, 0.0f};
#pragma unroll
                    for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
                        sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[0]);
                        sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[1]);
                        sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[2]);
                        sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[3]);
                    }
#pragma unroll
                    for (int mask = 1; mask <= 2; mask *= 2) {
                        sqrsum[0] += __shfl_xor_sync(~0, sqrsum[0], mask);
                        sqrsum[1] += __shfl_xor_sync(~0, sqrsum[1], mask);
@@ -343,14 +364,14 @@ public:
                }
            }

#pragma unroll
            for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
                const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
                for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
                    packed_f32psum_t rms = packed_fp16_to_fp32(load_rmsnorm_from_shmem(&shmem_rmsnorm[warpId][0], n));
#pragma unroll
                    for (int m = 0; m < WARP_M_TILES; m++) {
                        packed_f32psum_t pack = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n + n_offset]);
                        pack.data[0] *= rmsnorm_coef[head][m][0] * rms.data[0];
@@ -385,8 +406,8 @@ public:
            }
        }

        __device__ __forceinline__ void
        operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;
@@ -395,22 +416,20 @@ public:
            const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;

            if (is_q || is_k) {
                apply(fpsum,
                      args.rotary_emb + bm * NUM_WARPS * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE,
                      is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
                      args.epsilon);
            }
        }
    };

    struct EpiloguePackQKV {
        using attn_half_t = half;
        using attn_half2_t = half2;
        using packed_qkv_t = uint4;

        static constexpr int HEAD_DIM = 128;
        static constexpr int INSN_K_QK = 16;
        static constexpr int INSN_K_PV = 16;
@@ -424,8 +443,7 @@ public:
            int strideHead_v;
        };

        __device__ __forceinline__ static attn_half2_t convert_half2(half2_t input) {
            if constexpr (std::is_same_v<half2_t, attn_half2_t>) {
                return input;
            } else {
@@ -434,8 +452,7 @@ public:
            }
        }

        __device__ __forceinline__ static packed_qkv_t pack_q(packed_fpsum_t input) {
            packed_qkv_t output;
            output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
            output.y = kernels::bit_cast<int>(convert_half2(input.data[1]));
@@ -444,8 +461,7 @@ public:
            return output;
        }

        __device__ __forceinline__ static packed_qkv_t pack_k(packed_fpsum_t input) {
            packed_qkv_t output;
            output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
            output.y = kernels::bit_cast<int>(convert_half2(input.data[2]));
@@ -454,8 +470,7 @@ public:
            return output;
        }

        __device__ __forceinline__ static packed_qkv_t pack_v(packed_fpsum_t input) {
            packed_qkv_t output;
            output.x = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[0])));
            output.y = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[1])));
@@ -464,8 +479,7 @@ public:
            return output;
        }

        __device__ __forceinline__ static void mask(packed_qkv_t &pack, uint32_t maskVal, int m, int maxRows) {
            const int laneId = threadIdx.x % WARP_SIZE;
            if (m * INSN_M + laneId / 4 >= maxRows) {
                pack.x = maskVal;
@@ -479,15 +493,15 @@ public:
        // qkv: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_N_TILES, WARP_SIZE] of packed_qkv_t
        template<typename F>
        __device__ __forceinline__ static void
        apply(fpsum_warp &fpsum, packed_qkv_t *ptr_output, int maxRows, F &&funcPack, attn_half2_t maskVal) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            static_assert(HEAD_DIM == WARP_N);

            packed_qkv_t *ptrlane = &ptr_output[((warpId * WARP_M_TILES + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];

            unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
                unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
                    packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
@@ -497,35 +511,34 @@ public:
            });
        }

        __device__ __forceinline__ void
        operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            assert(binfo.numBlocksN % 3 == 0);
            const int numBlocksQ = binfo.numBlocksN / 3;
            const bool is_q = bn < numBlocksQ;
            const bool is_k = !is_q && bn < numBlocksQ * 2;

            // bn is head_id (assume HEAD_DIM == WARP_N)
            int head_id, strideHead;
            if (is_q) {
                head_id = bn;
                strideHead = args.strideHead_q;
            } else if (is_k) {
                head_id = bn - numBlocksQ;
                strideHead = args.strideHead_k;
            } else {
                head_id = bn - numBlocksQ * 2;
                strideHead = args.strideHead_v;
} }
int block_offset = head_id * strideHead + bm * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE; int block_offset = head_id * strideHead + bm * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE;
int maxRows = args.actualM - bm * BLOCK_M; int maxRows = args.actualM - bm * BLOCK_M;
// static constexpr float neginf = -std::numeric_limits<float>::infinity(); // static constexpr float neginf = -std::numeric_limits<float>::infinity();
if (is_q) { if (is_q) {
apply(fpsum, args.out_q + block_offset, maxRows, pack_q, attn_half2_t(0.0f, 0.0f)); apply(fpsum, args.out_q + block_offset, maxRows, pack_q, attn_half2_t(0.0f, 0.0f));
} else if (is_k) { } else if (is_k) {
...@@ -537,9 +550,9 @@ public: ...@@ -537,9 +550,9 @@ public:
}; };
struct EpilogueLiteLA { struct EpilogueLiteLA {
__device__ __forceinline__ __device__ __forceinline__ static packed_f32psum_t
static packed_f32psum_t mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) { mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
k.data[i] = movmatrix(k.data[i]); k.data[i] = movmatrix(k.data[i]);
v.data[i] = movmatrix(v.data[i]); v.data[i] = movmatrix(v.data[i]);
...@@ -549,20 +562,20 @@ public: ...@@ -549,20 +562,20 @@ public:
} }
static constexpr int LITELA_HEAD_DIM = 32; static constexpr int LITELA_HEAD_DIM = 32;
static constexpr int LITELA_K_TILES = LITELA_HEAD_DIM / 16; static constexpr int LITELA_K_TILES = LITELA_HEAD_DIM / 16;
static constexpr int LITELA_V_TILES = LITELA_HEAD_DIM / 16; static constexpr int LITELA_V_TILES = LITELA_HEAD_DIM / 16;
static constexpr int SHMEM_SIZE = NUM_WARPS * (LITELA_HEAD_DIM + 1) * (LITELA_HEAD_DIM + 8) * sizeof(float); static constexpr int SHMEM_SIZE = NUM_WARPS * (LITELA_HEAD_DIM + 1) * (LITELA_HEAD_DIM + 8) * sizeof(float);
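        // Shape of the computation implemented by apply_litela below together with vk_mul_q_kernel
        // (a sketch of the ReLU linear-attention form suggested by the relu on k, the extra
        // (LITELA_HEAD_DIM + 1)-th row, and the eps in vk_mul_q_kernel; D = LITELA_HEAD_DIM):
        //   vk[0:D, :] = sum_i v_i * relu(k_i)^T     // accumulated per head over all tokens i
        //   vk[D, :]   = sum_i relu(k_i)^T           // normalizer row (v replaced by ones below)
        //   out_j      = (vk[0:D, :] @ q)_j / (vk[D, :] @ q + eps)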
// out_vk: [batch_size, num_heads, head_dim + 1, head_dim] // out_vk: [batch_size, num_heads, head_dim + 1, head_dim]
__device__ __forceinline__ __device__ __forceinline__ static void
static void apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int num_blocks_per_batch) { apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int num_blocks_per_batch) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
using vk_t = float[NUM_WARPS][LITELA_HEAD_DIM + 1][LITELA_HEAD_DIM + 8]; using vk_t = float[NUM_WARPS][LITELA_HEAD_DIM + 1][LITELA_HEAD_DIM + 8];
extern __shared__ uint8_t shmem[]; extern __shared__ uint8_t shmem[];
vk_t &shmem_vk = *reinterpret_cast<vk_t *>(shmem); vk_t &shmem_vk = *reinterpret_cast<vk_t *>(shmem);
static_assert(sizeof(vk_t) == SHMEM_SIZE); static_assert(sizeof(vk_t) == SHMEM_SIZE);
...@@ -570,11 +583,13 @@ public: ...@@ -570,11 +583,13 @@ public:
assert(binfo.numBlocksN % 3 == 0); assert(binfo.numBlocksN % 3 == 0);
const int num_heads = binfo.numBlocksN / 3 * 2 * (WARP_N / (LITELA_HEAD_DIM * 2)); const int num_heads = binfo.numBlocksN / 3 * 2 * (WARP_N / (LITELA_HEAD_DIM * 2));
const int batch_id = binfo.bm / num_blocks_per_batch; const int batch_id = binfo.bm / num_blocks_per_batch;
for (int head_id = 0; head_id < WARP_N / (LITELA_HEAD_DIM * 2); head_id++) { for (int head_id = 0; head_id < WARP_N / (LITELA_HEAD_DIM * 2); head_id++) {
const int global_head_id = (binfo.bn - binfo.numBlocksN / 3) * (WARP_N / (LITELA_HEAD_DIM * 2)) + head_id; const int global_head_id =
float *out_vk_current_head = out_vk + (batch_id * num_heads + global_head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM; (binfo.bn - binfo.numBlocksN / 3) * (WARP_N / (LITELA_HEAD_DIM * 2)) + head_id;
float *out_vk_current_head =
out_vk + (batch_id * num_heads + global_head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM;
for (int i = laneId; i < sizeof(shmem_vk) / sizeof(float) / NUM_WARPS; i += WARP_SIZE) { for (int i = laneId; i < sizeof(shmem_vk) / sizeof(float) / NUM_WARPS; i += WARP_SIZE) {
*((&shmem_vk[warpId][0][0]) + i) = 0; *((&shmem_vk[warpId][0][0]) + i) = 0;
...@@ -583,12 +598,13 @@ public: ...@@ -583,12 +598,13 @@ public:
for (int tile_v = 0; tile_v < LITELA_V_TILES; tile_v++) { for (int tile_v = 0; tile_v < LITELA_V_TILES; tile_v++) {
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) { for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
packed_f32psum_t attn_sum = { 0 }; packed_f32psum_t attn_sum = {0};
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k]; packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + LITELA_HEAD_DIM / 16 + tile_v]; packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 +
LITELA_HEAD_DIM / 16 + tile_v];
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
} }
attn_sum = mma_litela(k, v, attn_sum); attn_sum = mma_litela(k, v, attn_sum);
} }
...@@ -607,14 +623,14 @@ public: ...@@ -607,14 +623,14 @@ public:
} }
} }
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) { for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
packed_f32psum_t attn_sum = { 0 }; packed_f32psum_t attn_sum = {0};
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k]; packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = {}; packed_fpsum_t v = {};
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
} }
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
v.data[i] = half2_t(1, 1); v.data[i] = half2_t(1, 1);
} }
...@@ -660,21 +676,23 @@ public: ...@@ -660,21 +676,23 @@ public:
int actualM; int actualM;
}; };
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) { operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm; const int bm = binfo.bm;
const int bn = binfo.bn; const int bn = binfo.bn;
if (bn < binfo.numBlocksN / 3) { if (bn < binfo.numBlocksN / 3) {
fpsum = apply_act(fpsum, [](half_t x) { return __hmax(x, 0); }); // relu fpsum = apply_act(fpsum, [](half_t x) { return __hmax(x, 0); }); // relu
return EpilogueDefault()( return EpilogueDefault()(binfo,
binfo, fpsum,
fpsum, M,
M, N / 3, K, typename EpilogueDefault::Arguments{ N / 3,
.out = args.out_q, K,
.actualM = args.actualM, typename EpilogueDefault::Arguments{
.actualN = N / 3, .out = args.out_q,
}); .actualM = args.actualM,
.actualN = N / 3,
});
} }
return apply_litela(binfo, fpsum, args.out_vk, args.num_blocks_per_batch); return apply_litela(binfo, fpsum, args.out_vk, args.num_blocks_per_batch);
...@@ -686,46 +704,49 @@ public: ...@@ -686,46 +704,49 @@ public:
struct vk_mul_q_kernel { struct vk_mul_q_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750; static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
// FIXME FIXME FIXME // FIXME FIXME FIXME
__device__ __device__ void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
const int block_id = blockIdx.x; const int block_id = blockIdx.x;
const int head_id = blockIdx.y; const int head_id = blockIdx.y;
const int batch_id = blockIdx.z; const int batch_id = blockIdx.z;
const int num_blocks = gridDim.x; const int num_blocks = gridDim.x;
const int num_heads = gridDim.y; const int num_heads = gridDim.y;
const int block_size = blockDim.x; const int block_size = blockDim.x;
bool pred = block_id * block_size + threadIdx.x < num_tokens; bool pred = block_id * block_size + threadIdx.x < num_tokens;
half_t *localq = &q[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM]; half_t *localq =
&q[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) *
LITELA_HEAD_DIM];
const float *localvk = &vk[(batch_id * num_heads + head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM]; const float *localvk = &vk[(batch_id * num_heads + head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM];
// half_t *localout = &out[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM]; // half_t *localout = &out[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads
// + head_id) * LITELA_HEAD_DIM];
using packed_q = std::array<half_t, 8>; using packed_q = std::array<half_t, 8>;
using packed_vk = std::array<float, 4>; using packed_vk = std::array<float, 4>;
half_t qblock[LITELA_HEAD_DIM]; half_t qblock[LITELA_HEAD_DIM];
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) { for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
if (pred) { if (pred) {
*reinterpret_cast<packed_q *>(&qblock[i]) = load(reinterpret_cast<const packed_q *>(&localq[i])); *reinterpret_cast<packed_q *>(&qblock[i]) =
load(reinterpret_cast<const packed_q *>(&localq[i]));
} }
} }
float outblock[LITELA_HEAD_DIM + 1]; float outblock[LITELA_HEAD_DIM + 1];
#pragma unroll #pragma unroll
for (int j = 0; j < LITELA_HEAD_DIM + 1; j++) { for (int j = 0; j < LITELA_HEAD_DIM + 1; j++) {
outblock[j] = 0; outblock[j] = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_vk) / sizeof(float)) { for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_vk) / sizeof(float)) {
packed_vk vkpack = load(reinterpret_cast<const packed_vk *>(&localvk[j * LITELA_HEAD_DIM + i])); packed_vk vkpack = load(reinterpret_cast<const packed_vk *>(&localvk[j * LITELA_HEAD_DIM + i]));
#pragma unroll #pragma unroll
for (int k = 0; k < vkpack.size(); k++) { for (int k = 0; k < vkpack.size(); k++) {
outblock[j] += (float)qblock[i + k] * vkpack[k]; outblock[j] += (float)qblock[i + k] * vkpack[k];
} }
} }
} }
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) { for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
packed_q opack; packed_q opack;
for (int k = 0; k < opack.size(); k++) { for (int k = 0; k < opack.size(); k++) {
...@@ -739,11 +760,11 @@ public: ...@@ -739,11 +760,11 @@ public:
}; };
}; };
template<typename Epilogue> template<typename Epilogue>
struct test_epilogue_kernel { struct test_epilogue_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750; static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(Base::template load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128; static constexpr size_t SHMEM_PER_WARP =
ceilDiv<size_t>(Base::template load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS; static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
struct Arguments { struct Arguments {
...@@ -757,18 +778,16 @@ public: ...@@ -757,18 +778,16 @@ public:
typename Epilogue::Arguments argsEpilogue; typename Epilogue::Arguments argsEpilogue;
}; };
__device__ __forceinline__ __device__ __forceinline__ void operator()(Arguments args) {
void operator()(Arguments args)
{
const BlockInfo binfo = { const BlockInfo binfo = {
.bm = (int)blockIdx.x, .bm = (int)blockIdx.x,
.bn = (int)blockIdx.y, .bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x, .numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y, .numBlocksN = (int)gridDim.y,
}; };
const int bm = binfo.bm; const int bm = binfo.bm;
const int bn = binfo.bn; const int bn = binfo.bn;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
const int m_offset = bm * BLOCK_M + warpId * WARP_M; const int m_offset = bm * BLOCK_M + warpId * WARP_M;
...@@ -778,26 +797,27 @@ public: ...@@ -778,26 +797,27 @@ public:
fpsum_warp fpsum; fpsum_warp fpsum;
Base::template load_act_to_fpsum<false>()( Base::template load_act_to_fpsum<false>()(args.input + m_offset * args.actualN + n_offset,
args.input + m_offset * args.actualN + n_offset, args.actualN,
args.actualN, args.actualM - m_offset,
args.actualM - m_offset, args.actualN - n_offset,
args.actualN - n_offset, fpsum,
fpsum, shmem + warpId * SHMEM_PER_WARP);
shmem + warpId * SHMEM_PER_WARP
);
Epilogue()(binfo, fpsum, args.M, args.N, 0, args.argsEpilogue); Epilogue()(binfo, fpsum, args.M, args.N, 0, args.argsEpilogue);
EpilogueDefault()(binfo, fpsum, args.M, args.N, 0, typename EpilogueDefault::Arguments{ EpilogueDefault()(binfo,
.out = args.output, fpsum,
.actualM = args.actualM, args.M,
.actualN = args.actualN, args.N,
}); 0,
typename EpilogueDefault::Arguments{
.out = args.output,
.actualM = args.actualM,
.actualN = args.actualN,
});
} }
}; };
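    // Hypothetical host-side launch of test_epilogue_kernel (illustrative only; the actual dispatch
    // wrapper is not shown here -- grid/block/shmem sizes follow from the indexing above):
    //   dim3   grid(ceilDiv(M, BLOCK_M), ceilDiv(N, BLOCK_N));          // one block per (bm, bn) tile
    //   int    threads = NUM_WARPS * WARP_SIZE;
    //   size_t shmem   = test_epilogue_kernel<Epilogue>::SHMEM_SIZE;    // dynamic shared memory per block
    //   launch<<<grid, threads, shmem>>>(args);                         // launch: hypothetical kernel wrapper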
}; };
}; // namespace nunchaku::kernels
}; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include "gemm_utils.cuh" #include "gemm_utils.cuh"
#include "mma_earlycuda.cuh" #include "mma_earlycuda.cuh"
#pragma nv_diag_suppress 177 #pragma nv_diag_suppress 177
#ifdef _MSC_VER #ifdef _MSC_VER
...@@ -32,8 +31,8 @@ class GEMMConfig_W4A4 { ...@@ -32,8 +31,8 @@ class GEMMConfig_W4A4 {
public: public:
// BE CAREFUL: weights need to be repacked when the tiling size changes // BE CAREFUL: weights need to be repacked when the tiling size changes
static constexpr int BLOCK_M = 256; static constexpr int BLOCK_M = 256;
static constexpr int BLOCK_N = 128; static constexpr int BLOCK_N = 128;
static constexpr int WARP_SIZE = 32; static constexpr int WARP_SIZE = 32;
static constexpr int NUM_WARPS = 8; static constexpr int NUM_WARPS = 8;
...@@ -45,18 +44,18 @@ public: ...@@ -45,18 +44,18 @@ public:
// may generate incorrect results in certain circumstances // may generate incorrect results in certain circumstances
static constexpr bool FASTER_I2F = faster_i2f; static constexpr bool FASTER_I2F = faster_i2f;
using half_t = typename std::conditional_t<bf16, __nv_bfloat16, half>; using half_t = typename std::conditional_t<bf16, __nv_bfloat16, half>;
using half2_t = typename std::conditional_t<bf16, __nv_bfloat162, half2>; using half2_t = typename std::conditional_t<bf16, __nv_bfloat162, half2>;
}; };
using GEMMConfig_W4A4_FP16 = GEMMConfig_W4A4<false>; using GEMMConfig_W4A4_FP16 = GEMMConfig_W4A4<false>;
using GEMMConfig_W4A4_BF16 = GEMMConfig_W4A4<true>; using GEMMConfig_W4A4_BF16 = GEMMConfig_W4A4<true>;
using GEMMConfig_W4A4_FP16_FasterI2F = GEMMConfig_W4A4<false, true>; using GEMMConfig_W4A4_FP16_FasterI2F = GEMMConfig_W4A4<false, true>;
class GEMMConfig_W8A8 { class GEMMConfig_W8A8 {
public: public:
static constexpr int BLOCK_M = 256; static constexpr int BLOCK_M = 256;
static constexpr int BLOCK_N = 128; static constexpr int BLOCK_N = 128;
static constexpr int WARP_SIZE = 32; static constexpr int WARP_SIZE = 32;
static constexpr int NUM_WARPS = 8; static constexpr int NUM_WARPS = 8;
...@@ -97,13 +96,13 @@ public: ...@@ -97,13 +96,13 @@ public:
/** /**
* refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c * refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
* *
* wscales store order: (pack = 4) * wscales store order: (pack = 4)
* 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x) * 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x) * 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x) * 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x) * 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
* *
* 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x) * 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ... * ...
* 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x) * 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
...@@ -111,24 +110,25 @@ public: ...@@ -111,24 +110,25 @@ public:
* 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x) * 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ... * ...
* 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x) * 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
* *
* wscales store order: (pack = 8) * wscales store order: (pack = 8)
* 0 1 8 9 16 17 24 25 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x) * 0 1 8 9 16 17 24 25 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 18 19 26 27 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x) * 2 3 10 11 18 19 26 27 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 20 21 28 29 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x) * 4 5 12 13 20 21 28 29 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 22 23 30 31 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x) * 6 7 14 15 22 23 30 31 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
* *
* 224 225 232 233 240 241 248 249 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x) * 224 225 232 233 240 241 248 249 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ... * ...
* 230 231 238 239 246 247 254 255 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x) * 230 231 238 239 246 247 254 255 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
* *
* {k}-th wscale used by lane {i} => {k // (WSCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {4*(k // WSCALES_PACK_SIZE) + i % 4}, element {k % WSCALES_PACK_SIZE} * {k}-th wscale used by lane {i} => {k // (WSCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {4*(k //
* * WSCALES_PACK_SIZE) + i % 4}, element {k % WSCALES_PACK_SIZE}
*
* max pack size set to 8 since max load size is 16 bytes / lane * max pack size set to 8 since max load size is 16 bytes / lane
     * min pack size set to 2 since the shuffle granularity is 32 bits (2 * half)                                              * min pack size set to 2 since the shuffle granularity is 32 bits (2 * half)
* */ * */
static constexpr int WSCALES_PACK_SIZE = clamp(WARP_N / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half)); static constexpr int WSCALES_PACK_SIZE = clamp(WARP_N / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
static constexpr int WSCALES_NUM_PACKS = ceilDiv(WARP_N, (WSCALES_PACK_SIZE * WARP_SIZE)); static constexpr int WSCALES_NUM_PACKS = ceilDiv(WARP_N, (WSCALES_PACK_SIZE * WARP_SIZE));
static constexpr int WSCALES_VALID_LANES = std::min(WARP_SIZE, WARP_N / WSCALES_PACK_SIZE); static constexpr int WSCALES_VALID_LANES = std::min(WARP_SIZE, WARP_N / WSCALES_PACK_SIZE);
/** /**
...@@ -145,16 +145,17 @@ public: ...@@ -145,16 +145,17 @@ public:
* ... * ...
* 54 62 * 54 62
* 55 63 <-- load by lane 31, broadcast to lane {28, 29, 30, 31} (4x) * 55 63 <-- load by lane 31, broadcast to lane {28, 29, 30, 31} (4x)
* *
     * {k}-th ascale used by lane {i} => {k // (ASCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {8*(k // ASCALES_PACK_SIZE) + i // 4}, element {k % ASCALES_PACK_SIZE}                      * {k}-th ascale used by lane {i} => {k // (ASCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {8*(k //
     * ASCALES_PACK_SIZE) + i // 4}, element {k % ASCALES_PACK_SIZE}
*/ */
static constexpr int ASCALES_PACK_SIZE = clamp(WARP_M / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half)); static constexpr int ASCALES_PACK_SIZE = clamp(WARP_M / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
static constexpr int ASCALES_NUM_PACKS = ceilDiv(WARP_M, (ASCALES_PACK_SIZE * WARP_SIZE)); static constexpr int ASCALES_NUM_PACKS = ceilDiv(WARP_M, (ASCALES_PACK_SIZE * WARP_SIZE));
static constexpr int ASCALES_VALID_LANES = std::min(WARP_SIZE, WARP_M / ASCALES_PACK_SIZE); static constexpr int ASCALES_VALID_LANES = std::min(WARP_SIZE, WARP_M / ASCALES_PACK_SIZE);
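    // The pack/lane/element mapping from the two layout comments above, spelled out as index math
    // (a restatement only; broadcast_wscale / broadcast_ascale below implement exactly this):
    //   // k-th wscale requested by lane i:
    //   int w_pack = k / (WSCALES_PACK_SIZE * WARP_SIZE);
    //   int w_lane = 4 * (k / WSCALES_PACK_SIZE) + i % 4;
    //   int w_elem = k % WSCALES_PACK_SIZE;
    //   // k-th ascale requested by lane i:
    //   int a_pack = k / (ASCALES_PACK_SIZE * WARP_SIZE);
    //   int a_lane = 8 * (k / ASCALES_PACK_SIZE) + i / 4;
    //   int a_elem = k % ASCALES_PACK_SIZE;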
using packed_act_t = uint4; using packed_act_t = uint4;
using packed_wgt_t = uint4; using packed_wgt_t = uint4;
struct alignas(32) packed_psum_t { struct alignas(32) packed_psum_t {
int data[8]; int data[8];
}; };
...@@ -184,15 +185,14 @@ public: ...@@ -184,15 +185,14 @@ public:
half2_t data[ASCALES_PACK_SIZE / 2]; half2_t data[ASCALES_PACK_SIZE / 2];
}; };
using act_warp = std::array<packed_act_t, WARP_M_TILES>; using act_warp = std::array<packed_act_t, WARP_M_TILES>;
using wgt_warp = std::array<packed_wgt_t, WARP_N_TILES>; using wgt_warp = std::array<packed_wgt_t, WARP_N_TILES>;
using ascale_warp = std::array<packed_ascale_t, ASCALES_NUM_PACKS>; using ascale_warp = std::array<packed_ascale_t, ASCALES_NUM_PACKS>;
using wscale_warp = std::array<packed_wscale_t, WSCALES_NUM_PACKS>; using wscale_warp = std::array<packed_wscale_t, WSCALES_NUM_PACKS>;
using fpsum_warp = std::array<packed_fpsum_t, WARP_M_TILES * WARP_N_TILES>; using fpsum_warp = std::array<packed_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
using f32psum_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>; using f32psum_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>;
using gated_fpsum_warp = std::array<packed_gated_fpsum_t, WARP_M_TILES * WARP_N_TILES>; using gated_fpsum_warp = std::array<packed_gated_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
struct BlockInfo { struct BlockInfo {
int bm; int bm;
int bn; int bn;
...@@ -200,19 +200,19 @@ public: ...@@ -200,19 +200,19 @@ public:
int numBlocksN; int numBlocksN;
}; };
__device__ __forceinline__ __device__ __forceinline__ static packed_f32psum_t
static packed_f32psum_t mma_f16xf16_f32(packed_fpsum_t a, packed_fpsum_t b, packed_f32psum_t psum) { mma_f16xf16_f32(packed_fpsum_t a, packed_fpsum_t b, packed_f32psum_t psum) {
static_assert(std::is_same_v<half_t, half> || std::is_same_v<half_t, __nv_bfloat16>); static_assert(std::is_same_v<half_t, half> || std::is_same_v<half_t, __nv_bfloat16>);
static constexpr bool is_bf16 = std::is_same_v<half_t, __nv_bfloat16>; static constexpr bool is_bf16 = std::is_same_v<half_t, __nv_bfloat16>;
uint4 out1 = mma_m16n8k16_f32f16f16f32<is_bf16>( uint4 out1 = mma_m16n8k16_f32f16f16f32<is_bf16>(
kernels::bit_cast<uint4>(a), kernels::bit_cast<uint4>(a),
kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])), kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])),
kernels::bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3]))); kernels::bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3])));
uint4 out2 = mma_m16n8k16_f32f16f16f32<is_bf16>( uint4 out2 = mma_m16n8k16_f32f16f16f32<is_bf16>(
kernels::bit_cast<uint4>(a), kernels::bit_cast<uint4>(a),
kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])), kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])),
kernels::bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7]))); kernels::bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7])));
psum.data[0] = kernels::bit_cast<float>(out1.x); psum.data[0] = kernels::bit_cast<float>(out1.x);
psum.data[1] = kernels::bit_cast<float>(out1.y); psum.data[1] = kernels::bit_cast<float>(out1.y);
...@@ -222,12 +222,11 @@ public: ...@@ -222,12 +222,11 @@ public:
psum.data[5] = kernels::bit_cast<float>(out2.y); psum.data[5] = kernels::bit_cast<float>(out2.y);
psum.data[6] = kernels::bit_cast<float>(out2.z); psum.data[6] = kernels::bit_cast<float>(out2.z);
psum.data[7] = kernels::bit_cast<float>(out2.w); psum.data[7] = kernels::bit_cast<float>(out2.w);
return psum; return psum;
} }
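    // The two m16n8k16 MMAs above together form one 16x16x16 tile: {b.data[0], b.data[1]} / out1
    // cover the first 8 output columns (psum.data[0..3]), while {b.data[2], b.data[3]} / out2
    // cover the remaining 8 columns (psum.data[4..7]).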
__device__ __forceinline__ __device__ __forceinline__ static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
packed_fpsum_t results; packed_fpsum_t results;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
results.data[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1])); results.data[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1]));
...@@ -235,31 +234,28 @@ public: ...@@ -235,31 +234,28 @@ public:
return results; return results;
} }
__device__ __forceinline__ __device__ __forceinline__ static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
packed_f32psum_t results; packed_f32psum_t results;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
float2 tmp = half22float2(input.data[i]); float2 tmp = half22float2(input.data[i]);
results.data[i * 2] = tmp.x; results.data[i * 2] = tmp.x;
results.data[i * 2 + 1] = tmp.y; results.data[i * 2 + 1] = tmp.y;
} }
return results; return results;
} }
__device__ __forceinline__ __device__ __forceinline__ static fpsum_warp packed_fp32_to_fp16(f32psum_warp input) {
static fpsum_warp packed_fp32_to_fp16(f32psum_warp input) {
fpsum_warp results; fpsum_warp results;
#pragma unroll #pragma unroll
for (int i = 0; i < results.size(); i++) { for (int i = 0; i < results.size(); i++) {
results[i] = packed_fp32_to_fp16(input[i]); results[i] = packed_fp32_to_fp16(input[i]);
} }
return results; return results;
} }
__device__ __forceinline__ __device__ __forceinline__ static f32psum_warp packed_fp16_to_fp32(fpsum_warp input) {
static f32psum_warp packed_fp16_to_fp32(fpsum_warp input) {
f32psum_warp results; f32psum_warp results;
#pragma unroll #pragma unroll
for (int i = 0; i < results.size(); i++) { for (int i = 0; i < results.size(); i++) {
results[i] = packed_fp16_to_fp32(input[i]); results[i] = packed_fp16_to_fp32(input[i]);
} }
...@@ -267,108 +263,110 @@ public: ...@@ -267,108 +263,110 @@ public:
} }
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t
__device__ __forceinline__ __device__ __forceinline__ static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE; int laneId = threadIdx.x % WARP_SIZE;
int warpId = threadIdx.x / WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
//if (pred) { // if (pred) {
// out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]); // out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
out[i] = load_pred(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], pred); out[i] = load_pred(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], pred);
//} //}
} }
} }
// weight: column major: [N / BLOCK_N, 1, K / WARP_K, WARP_N_TILES, WARP_SIZE] of packed_wgt_t // weight: column major: [N / BLOCK_N, 1, K / WARP_K, WARP_N_TILES, WARP_SIZE] of packed_wgt_t
__device__ __forceinline__ __device__ __forceinline__ static void load_wgt(const packed_wgt_t *wgt, int k, int K, wgt_warp &out, bool pred) {
static void load_wgt(const packed_wgt_t *wgt, int k, int K, wgt_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE; int laneId = threadIdx.x % WARP_SIZE;
// const packed_wgt_t *ptr = &wgt[(0 * K / WARP_K + k) * WARP_SIZE + laneId]; // const packed_wgt_t *ptr = &wgt[(0 * K / WARP_K + k) * WARP_SIZE + laneId];
const packed_wgt_t *ptr = &wgt[(0 + k * WARP_N_TILES) * WARP_SIZE + laneId]; const packed_wgt_t *ptr = &wgt[(0 + k * WARP_N_TILES) * WARP_SIZE + laneId];
// int offset = K / WARP_K * WARP_SIZE; // int offset = K / WARP_K * WARP_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_N_TILES; i++) { for (int i = 0; i < WARP_N_TILES; i++) {
//if (pred) { // if (pred) {
// out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]); // out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
// out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]); // out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
out[i] = load_pred(&ptr[i * WARP_SIZE], pred); out[i] = load_pred(&ptr[i * WARP_SIZE], pred);
// ptr += offset; // ptr += offset;
//} //}
} }
} }
// ascales: row major [M / BLOCK_M, K / group size, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t // ascales: row major [M / BLOCK_M, K / group size, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of
__device__ __forceinline__ // packed_ascale_t
static void load_ascale(const packed_ascale_t *ascales, int group, int M, ascale_warp &out, bool pred) { __device__ __forceinline__ static void
load_ascale(const packed_ascale_t *ascales, int group, int M, ascale_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE; int laneId = threadIdx.x % WARP_SIZE;
int warpId = threadIdx.x / WARP_SIZE; int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < ASCALES_NUM_PACKS; i++) { for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
// if (pred && laneId < ASCALES_VALID_LANES) { // if (pred && laneId < ASCALES_VALID_LANES) {
// out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId]; // out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i *
out[i] = load_pred(&ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId], pred && laneId < ASCALES_VALID_LANES); // ASCALES_VALID_LANES + laneId];
out[i] = load_pred(&ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES +
i * ASCALES_VALID_LANES + laneId],
pred && laneId < ASCALES_VALID_LANES);
// } // }
} }
} }
    // wscales: column major [N / BLOCK_N, K / group size, 1, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t                      // wscales: column major [N / BLOCK_N, K / group size, 1, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
    __device__ __forceinline__
static void load_wscale(const packed_wscale_t *wscales, int group, int N, wscale_warp &out, bool pred) { __device__ __forceinline__ static void
load_wscale(const packed_wscale_t *wscales, int group, int N, wscale_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE; int laneId = threadIdx.x % WARP_SIZE;
// static_assert(WSCALES_NUM_PACKS * WSCALES_VALID_LANES == 32); // static_assert(WSCALES_NUM_PACKS * WSCALES_VALID_LANES == 32);
// static_assert(sizeof(packed_wscale_t) == 8); // static_assert(sizeof(packed_wscale_t) == 8);
// const packed_wscale_t *ptr = &wscales[(group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId]; // const packed_wscale_t *ptr = &wscales[(group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId];
// // const packed_wscale_t *ptr = (const packed_wscale_t *)((const char *)wscales) + ((group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId) * sizeof(packed_wscale_t); // // const packed_wscale_t *ptr = (const packed_wscale_t *)((const char *)wscales) + ((group *
// WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId) * sizeof(packed_wscale_t);
#pragma unroll #pragma unroll
for (int i = 0; i < WSCALES_NUM_PACKS; i++) { for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
// if (pred && laneId < WSCALES_VALID_LANES) { // if (pred && laneId < WSCALES_VALID_LANES) {
// out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]; // out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES +
// out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]); // laneId]; out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i *
out[i] = load_pred(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId], pred && laneId < WSCALES_VALID_LANES); // WSCALES_VALID_LANES + laneId]);
// out[i] = load(&ptr[i * WSCALES_VALID_LANES]); out[i] = load_pred(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId],
pred && laneId < WSCALES_VALID_LANES);
// out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
// } // }
} }
} }
    // get the {k}-th and {k+1}-th wscale from the block; k must be a multiple of 2 and uniform across all lanes                      // get the {k}-th and {k+1}-th wscale from the block; k must be a multiple of 2 and uniform across all lanes
__device__ __forceinline__ __device__ __forceinline__ static half2_t broadcast_wscale(wscale_warp block, int k, int laneId) {
static half2_t broadcast_wscale(wscale_warp block, int k, int laneId) { const int packIdx = k / (WSCALES_PACK_SIZE * WARP_SIZE);
const int packIdx = k / (WSCALES_PACK_SIZE * WARP_SIZE); const int srcLane = 4 * (k / WSCALES_PACK_SIZE) + laneId % 4;
const int srcLane = 4 * (k / WSCALES_PACK_SIZE) + laneId % 4;
const int elementIdx = k % WSCALES_PACK_SIZE / 2; const int elementIdx = k % WSCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane); return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
} }
    // get the {k}-th and {k+1}-th ascale from the block; k must be a multiple of 2 and uniform across all lanes                      // get the {k}-th and {k+1}-th ascale from the block; k must be a multiple of 2 and uniform across all lanes
__device__ __forceinline__ __device__ __forceinline__ static half2_t broadcast_ascale(ascale_warp block, int k, int laneId) {
static half2_t broadcast_ascale(ascale_warp block, int k, int laneId) { const int packIdx = k / (ASCALES_PACK_SIZE * WARP_SIZE);
const int packIdx = k / (ASCALES_PACK_SIZE * WARP_SIZE); const int srcLane = 8 * (k / ASCALES_PACK_SIZE) + laneId / 4;
const int srcLane = 8 * (k / ASCALES_PACK_SIZE) + laneId / 4;
const int elementIdx = k % ASCALES_PACK_SIZE / 2; const int elementIdx = k % ASCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane); return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
} }
struct i2f_normal { struct i2f_normal {
__device__ __forceinline__ __device__ __forceinline__ static float2 int2float2(int x, int y) {
static float2 int2float2(int x, int y) {
return make_float2(__int2float_rn(x), __int2float_rn(y)); return make_float2(__int2float_rn(x), __int2float_rn(y));
} }
__device__ __forceinline__ __device__ __forceinline__ static half2_t int2half2(int x, int y) {
static half2_t int2half2(int x, int y) {
return float22half2<half2_t>(int2float2(x, y)); return float22half2<half2_t>(int2float2(x, y));
} }
}; };
template<typename i2f = i2f_normal, typename F> template<typename i2f = i2f_normal, typename F>
__device__ __forceinline__ __device__ __forceinline__ static void
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) { apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
...@@ -377,8 +375,8 @@ public: ...@@ -377,8 +375,8 @@ public:
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
half2_t as = broadcast_ascale(ascale, i * 2, laneId); half2_t as = broadcast_ascale(ascale, i * 2, laneId);
asx[i] = half2_t(as.x, as.x); asx[i] = half2_t(as.x, as.x);
asy[i] = half2_t(as.y, as.y); asy[i] = half2_t(as.y, as.y);
} }
for (int j = 0; j < WARP_N_TILES; j++) { for (int j = 0; j < WARP_N_TILES; j++) {
...@@ -392,8 +390,9 @@ public: ...@@ -392,8 +390,9 @@ public:
// constexpr int target = 0; // constexpr int target = 0;
// if (threadIdx.x == 3 && j == 1 && i == 0) { // if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y); // printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target,
// (float)fsum.data[target].x, (float)fsum.data[target].y);
// } // }
fsum.data[0] = __hfma2(i2f::int2half2(psum.data[0], psum.data[1]), __hmul2(asx[i], ws1), fsum.data[0]); fsum.data[0] = __hfma2(i2f::int2half2(psum.data[0], psum.data[1]), __hmul2(asx[i], ws1), fsum.data[0]);
...@@ -402,15 +401,16 @@ public: ...@@ -402,15 +401,16 @@ public:
fsum.data[3] = __hfma2(i2f::int2half2(psum.data[6], psum.data[7]), __hmul2(asy[i], ws2), fsum.data[3]); fsum.data[3] = __hfma2(i2f::int2half2(psum.data[6], psum.data[7]), __hmul2(asy[i], ws2), fsum.data[3]);
// if (threadIdx.x == 3 && j == 1 && i == 0) { // if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y); // printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target,
// (float)fsum.data[target].x, (float)fsum.data[target].y);
// } // }
} }
} }
} }
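    // Per-element effect of apply_scales (scalar sketch of the __hfma2 lines above): the int32
    // partial sums are dequantized by a per-row activation scale and a per-column weight scale,
    // then accumulated into the fp16/bf16 tile:
    //   fsum[m][n] += half(psum[m][n]) * ascale[m] * wscale[n];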
template<typename i2f = i2f_normal, typename F> template<typename i2f = i2f_normal, typename F>
__device__ __forceinline__ __device__ __forceinline__ static void
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, f32psum_warp &fpsum) { apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, f32psum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
...@@ -419,8 +419,8 @@ public: ...@@ -419,8 +419,8 @@ public:
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
half2_t as = broadcast_ascale(ascale, i * 2, laneId); half2_t as = broadcast_ascale(ascale, i * 2, laneId);
asx[i] = half22float2(half2_t(as.x, as.x)); asx[i] = half22float2(half2_t(as.x, as.x));
asy[i] = half22float2(half2_t(as.y, as.y)); asy[i] = half22float2(half2_t(as.y, as.y));
} }
auto fma2 = [](float2 a, float2 b, float &cx, float &cy) ALWAYSINLINE { auto fma2 = [](float2 a, float2 b, float &cx, float &cy) ALWAYSINLINE {
...@@ -449,17 +449,18 @@ public: ...@@ -449,17 +449,18 @@ public:
* input: WARP_M of half (in shared memory, per-warp) * input: WARP_M of half (in shared memory, per-warp)
* output: [..., ASCALES_NUM_PACKS, ASCALES_VALID_LANES] in global memory, per-warp * output: [..., ASCALES_NUM_PACKS, ASCALES_VALID_LANES] in global memory, per-warp
*/ */
__device__ __forceinline__ __device__ __forceinline__ static void pack_ascales(const half_t *input, packed_ascale_t *output) {
static void pack_ascales(const half_t *input, packed_ascale_t *output) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll #pragma unroll
for (int j = 0; j < ASCALES_NUM_PACKS; j++) { for (int j = 0; j < ASCALES_NUM_PACKS; j++) {
if (laneId < ASCALES_VALID_LANES) { if (laneId < ASCALES_VALID_LANES) {
packed_ascale_t tmp; packed_ascale_t tmp;
#pragma unroll #pragma unroll
for (int i = 0; i < ASCALES_PACK_SIZE; i += 2) { for (int i = 0; i < ASCALES_PACK_SIZE; i += 2) {
tmp.data[i / 2].x = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE + laneId % 8 + i * 8]; tmp.data[i / 2].x = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE +
tmp.data[i / 2].y = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE + laneId % 8 + (i + 1) * 8]; laneId % 8 + i * 8];
tmp.data[i / 2].y = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE +
laneId % 8 + (i + 1) * 8];
} }
output[j * ASCALES_VALID_LANES + laneId] = tmp; output[j * ASCALES_VALID_LANES + laneId] = tmp;
} }
...@@ -470,16 +471,17 @@ public: ...@@ -470,16 +471,17 @@ public:
* input: WARP_N of half (in shared memory, per-warp) * input: WARP_N of half (in shared memory, per-warp)
* output: [..., WSCALES_NUM_PACKS, WSCALES_VALID_LANES] in global memory, per-warp * output: [..., WSCALES_NUM_PACKS, WSCALES_VALID_LANES] in global memory, per-warp
*/ */
__device__ __forceinline__ __device__ __forceinline__ static void pack_wscales(const half_t *input, packed_wscale_t *output) {
static void pack_wscales(const half_t *input, packed_wscale_t *output) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll #pragma unroll
for (int j = 0; j < WSCALES_NUM_PACKS; j++) { for (int j = 0; j < WSCALES_NUM_PACKS; j++) {
if (laneId < WSCALES_VALID_LANES) { if (laneId < WSCALES_VALID_LANES) {
packed_wscale_t tmp; packed_wscale_t tmp;
#pragma unroll #pragma unroll
for (int i = 0; i < WSCALES_PACK_SIZE; i += 2) { for (int i = 0; i < WSCALES_PACK_SIZE; i += 2) {
tmp.data[i / 2] = *reinterpret_cast<const half2_t *>(&input[j * WSCALES_PACK_SIZE * WARP_SIZE + laneId / 4 * 4 * WSCALES_PACK_SIZE + laneId % 4 * 2 + i * 4]); tmp.data[i / 2] = *reinterpret_cast<const half2_t *>(
&input[j * WSCALES_PACK_SIZE * WARP_SIZE + laneId / 4 * 4 * WSCALES_PACK_SIZE + laneId % 4 * 2 +
i * 4]);
} }
store(&output[j * WSCALES_VALID_LANES + laneId], tmp); store(&output[j * WSCALES_VALID_LANES + laneId], tmp);
} }
...@@ -491,17 +493,16 @@ public: ...@@ -491,17 +493,16 @@ public:
using matrix_t = half_t[8][WARP_N + 8]; using matrix_t = half_t[8][WARP_N + 8];
static constexpr int SHMEM_SIZE = sizeof(matrix_t); static constexpr int SHMEM_SIZE = sizeof(matrix_t);
static constexpr int PACK_SIZE = WARP_N / WARP_SIZE; static constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>; using pack_t = std::array<half_t, PACK_SIZE>;
// F (int rowId, pack_t &pack) // F (int rowId, pack_t &pack)
template<typename ...F> template<typename... F>
__device__ __forceinline__ __device__ __forceinline__ void operator()(
void operator()(fpsum_warp fpsum, half_t *output, int stride, int maxRows, int maxCols, void *shmem, F &&...plugins) { fpsum_warp fpsum, half_t *output, int stride, int maxRows, int maxCols, void *shmem, F &&...plugins) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
// pack_t reduce_tmp; // pack_t reduce_tmp;
// constexpr bool enableReduce = !std::is_void_v<FuncReduce>; // constexpr bool enableReduce = !std::is_void_v<FuncReduce>;
...@@ -518,19 +519,19 @@ public: ...@@ -518,19 +519,19 @@ public:
// } // }
// }; // };
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) { for (int j = 0; j < WARP_N_TILES; j++) {
packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j]; packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4; int row = laneId / 4;
int col = laneId % 4 * 2 + j * INSN_N; int col = laneId % 4 * 2 + j * INSN_N;
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[0]; *reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[2]; *reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[2];
} }
__syncwarp(); __syncwarp();
#pragma unroll #pragma unroll
for (int row = 0; row < 8; row++) { for (int row = 0; row < 8; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]); pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
...@@ -542,22 +543,24 @@ public: ...@@ -542,22 +543,24 @@ public:
bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols; bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols;
// if (pred) { // if (pred) {
store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack, pred); store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]),
pack,
pred);
// } // }
} }
__syncwarp(); __syncwarp();
#pragma unroll #pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) { for (int j = 0; j < WARP_N_TILES; j++) {
packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j]; packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4; int row = laneId / 4;
int col = laneId % 4 * 2 + j * INSN_N; int col = laneId % 4 * 2 + j * INSN_N;
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[1]; *reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[1];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[3]; *reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[3];
} }
__syncwarp(); __syncwarp();
#pragma unroll #pragma unroll
for (int row = 0; row < 8; row++) { for (int row = 0; row < 8; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]); pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
...@@ -569,7 +572,10 @@ public: ...@@ -569,7 +572,10 @@ public:
bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols; bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols;
// if (pred) { // if (pred) {
store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack, pred); store_pred(
reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]),
pack,
pred);
// } // }
} }
__syncwarp(); __syncwarp();
...@@ -584,22 +590,22 @@ public: ...@@ -584,22 +590,22 @@ public:
// [WARP_M, WARP_N * 2] when fuse_glu // [WARP_M, WARP_N * 2] when fuse_glu
template<bool fuse_glu> template<bool fuse_glu>
struct load_act_to_fpsum { struct load_act_to_fpsum {
using matrix_t = half_t[INSN_M][WARP_N + 8]; using matrix_t = half_t[INSN_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t); static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const half_t *input, int stride, int maxRows, int maxCols, fpsum_warp &out, void *shmem) { operator()(const half_t *input, int stride, int maxRows, int maxCols, fpsum_warp &out, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem); matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
constexpr int PACK_SIZE = WARP_N / WARP_SIZE; constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using packed_input = std::array<half_t, PACK_SIZE>; using packed_input = std::array<half_t, PACK_SIZE>;
using packed_raw_input = std::array<half2_t, PACK_SIZE>; using packed_raw_input = std::array<half2_t, PACK_SIZE>;
#pragma unroll #pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) { for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll #pragma unroll
for (int row = 0; row < INSN_M; row++) { for (int row = 0; row < INSN_M; row++) {
packed_input pack; packed_input pack;
                // TODO: handle numCols that is not a multiple of PACK_SIZE                 // TODO: handle numCols that is not a multiple of PACK_SIZE
...@@ -608,9 +614,10 @@ public: ...@@ -608,9 +614,10 @@ public:
raw.fill(half2_t(0, 0)); raw.fill(half2_t(0, 0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE * 2 < maxCols; bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE * 2 < maxCols;
if (pred) { if (pred) {
raw = load(reinterpret_cast<const packed_raw_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE * 2)); raw = load(reinterpret_cast<const packed_raw_input *>(input + (m * INSN_M + row) * stride +
laneId * PACK_SIZE * 2));
} }
#pragma unroll #pragma unroll
for (int j = 0; j < PACK_SIZE; j++) { for (int j = 0; j < PACK_SIZE; j++) {
pack[j] = raw[j].x * silu(raw[j].y); pack[j] = raw[j].x * silu(raw[j].y);
} }
...@@ -618,7 +625,8 @@ public: ...@@ -618,7 +625,8 @@ public:
pack.fill(half_t(0)); pack.fill(half_t(0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE < maxCols; bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) { if (pred) {
pack = load(reinterpret_cast<const packed_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE)); pack = load(reinterpret_cast<const packed_input *>(input + (m * INSN_M + row) * stride +
laneId * PACK_SIZE));
} }
} }
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack); store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
...@@ -637,27 +645,24 @@ public: ...@@ -637,27 +645,24 @@ public:
} }
}; };
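    // Reference for the fuse_glu path above (plain scalar code, not the warp-tiled version): each
    // output column j pairs the two adjacent input columns (x, y) = (in[2j], in[2j+1]) and stores
    // x * silu(y):
    //   for (int j = 0; j < WARP_N; j++) {
    //       float x = in[2 * j], y = in[2 * j + 1];
    //       out[j] = x * (y / (1.0f + expf(-y)));   // silu(y) = y * sigmoid(y)
    //   }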
template<typename F> template<typename F>
__device__ __forceinline__ __device__ __forceinline__ static fpsum_warp apply_act(fpsum_warp fpsum, F func) {
static fpsum_warp apply_act(fpsum_warp fpsum, F func) {
fpsum_warp result; fpsum_warp result;
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) { for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) { for (int k = 0; k < 4; k++) {
half2_t &dst = result[i * WARP_N_TILES + j].data[k]; half2_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k]; half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst.x = func(src.x); dst.x = func(src.x);
dst.y = func(src.y); dst.y = func(src.y);
} }
} }
} }
return result; return result;
} }
struct EpilogueDefault { struct EpilogueDefault {
struct Arguments { struct Arguments {
...@@ -665,57 +670,57 @@ public: ...@@ -665,57 +670,57 @@ public:
int actualM, actualN; int actualM, actualN;
}; };
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) { operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128]; __shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
const int m_offset = binfo.bm * BLOCK_M + warpId * WARP_M; const int m_offset = binfo.bm * BLOCK_M + warpId * WARP_M;
const int n_offset = binfo.bn * BLOCK_N; const int n_offset = binfo.bn * BLOCK_N;
unpack_fpsum()( unpack_fpsum()(fpsum,
fpsum, args.out + m_offset * args.actualN + n_offset,
args.out + m_offset * args.actualN + n_offset, args.actualN,
args.actualN, args.actualM - m_offset,
args.actualM - m_offset, args.actualN - n_offset,
args.actualN - n_offset, shmem[warpId],
shmem[warpId], [&](int rowId, unpack_fpsum::pack_t &pack) ALWAYSINLINE {
[&](int rowId, unpack_fpsum::pack_t &pack) ALWAYSINLINE { if constexpr (std::is_same_v<half_t, half>) {
if constexpr (std::is_same_v<half_t, half>) { #pragma unroll
#pragma unroll for (int i = 0; i < pack.size(); i++) {
for (int i = 0; i < pack.size(); i++) { pack[i] = __hmin(pack[i], (half)65504);
pack[i] = __hmin(pack[i], (half)65504); pack[i] = __hmax(pack[i], (half)-65504);
pack[i] = __hmax(pack[i], (half)-65504); }
} }
} });
}
);
} }
}; };
struct EpilogueNop { struct EpilogueNop {
// workaround for layout mismatch between host and device code // workaround for layout mismatch between host and device code
struct Arguments { size_t unused; }; struct Arguments {
size_t unused;
};
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) { operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {}
}
}; };
template<bool USE_BIAS = true, bool USE_SCALE = false> template<bool USE_BIAS = true, bool USE_SCALE = false>
struct EpilogueBias { struct EpilogueBias {
struct Arguments { struct Arguments {
const packed_wscale_t *bias; // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t const packed_wscale_t *bias; // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
const packed_wscale_t *scale; const packed_wscale_t *scale;
}; };
__device__ __forceinline__ __device__ __forceinline__ void
void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias, const packed_wscale_t *scale) { apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias, const packed_wscale_t *scale) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
// if (laneId == 0) { // if (laneId == 0) {
// printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias); // printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE,
// bias);
// } // }
wscale_warp b, s; wscale_warp b, s;
...@@ -737,7 +742,6 @@ public: ...@@ -737,7 +742,6 @@ public:
s1 = broadcast_wscale(s, j * 4, laneId); s1 = broadcast_wscale(s, j * 4, laneId);
s2 = broadcast_wscale(s, j * 4 + 2, laneId); s2 = broadcast_wscale(s, j * 4 + 2, laneId);
} }
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j]; auto &fsum = fpsum[i * WARP_N_TILES + j];
...@@ -762,114 +766,114 @@ public: ...@@ -762,114 +766,114 @@ public:
} }
} }
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) { operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bn = binfo.bn; const int bn = binfo.bn;
if constexpr (USE_BIAS || USE_SCALE) { if constexpr (USE_BIAS || USE_SCALE) {
apply_bias( apply_bias(fpsum,
fpsum, M, N, K, M,
args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES, N,
args.scale + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES K,
); args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
args.scale + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES);
} }
} }
}; };
struct EpilogueSilu { struct EpilogueSilu {
struct Arguments { size_t unused; }; struct Arguments {
size_t unused;
};
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) { operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
fpsum = apply_act(fpsum, [](half_t x) { return silu(x); }); fpsum = apply_act(fpsum, [](half_t x) { return silu(x); });
} }
}; };
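// silu() is defined elsewhere in this header; as a rough reference (a sketch,
// not the library's implementation), the standard SiLU activation it is
// expected to compute is x * sigmoid(x):
//
//   __device__ __forceinline__ half_t silu_reference(half_t x) {  // hypothetical helper
//       float xf = float(x);
//       return half_t(xf / (1.0f + expf(-xf)));                   // x * sigmoid(x)
//   }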
template<typename ...Epilogues> template<typename... Epilogues>
struct EpilogueCombination { struct EpilogueCombination {
using Arguments = std::tuple<typename Epilogues::Arguments...>; using Arguments = std::tuple<typename Epilogues::Arguments...>;
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) { operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
// this function makes IntelliSense crash :( // this function makes IntelliSense crash :(
#if __INTELLISENSE__ #if __INTELLISENSE__
__trap(); // should not happen when actually compiling __trap(); // should not happen when actually compiling
#else #else
std::tuple<Epilogues...> epilogues; std::tuple<Epilogues...> epilogues;
auto run = [&]<size_t idx>() { auto run = [&]<size_t idx>() {
std::get<idx>(epilogues).operator()(binfo, fpsum, M, N, K, std::get<idx>(args)); std::get<idx>(epilogues).operator()(binfo, fpsum, M, N, K, std::get<idx>(args));
}; };
auto foreach = [&]<size_t ...Is>(std::index_sequence<Is...>) { auto foreach = [&]<size_t... Is>(std::index_sequence<Is...>) { (run.template operator()<Is>(), ...); };
(run.template operator()<Is>(), ...); foreach (std::make_index_sequence<sizeof...(Epilogues)>())
}; ;
foreach(std::make_index_sequence<sizeof...(Epilogues)>()); #endif
#endif
} }
}; };
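// EpilogueCombination chains several epilogues in order and packs their
// Arguments into one std::tuple. A hypothetical composition inside a derived
// kernel (illustrative sketch only; pointer names and the EpilogueDefault
// fields are assumptions) could look like:
//
//   using Epilogue = EpilogueCombination<EpilogueBias<true, false>, EpilogueSilu, EpilogueDefault>;
//   Epilogue::Arguments args = std::make_tuple(
//       EpilogueBias<true, false>::Arguments{bias_ptr, scale_ptr},
//       EpilogueSilu::Arguments{},
//       EpilogueDefault::Arguments{/* out pointer, actualM, actualN, ... */});
//   Epilogue()(binfo, fpsum, M, N, K, args);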
}; };
#define IMPORT_GEMM_BASE(config) \ #define IMPORT_GEMM_BASE(config) \
using Base = GEMMBase<config>; \ using Base = GEMMBase<config>; \
using Base::BLOCK_M; \ using Base::BLOCK_M; \
using Base::BLOCK_N; \ using Base::BLOCK_N; \
using Base::WARP_SIZE; \ using Base::WARP_SIZE; \
using Base::NUM_WARPS; \ using Base::NUM_WARPS; \
using Base::INSN_M; \ using Base::INSN_M; \
using Base::INSN_N; \ using Base::INSN_N; \
using Base::INSN_K; \ using Base::INSN_K; \
using typename Base::half_t; \ using typename Base::half_t; \
using typename Base::half2_t; \ using typename Base::half2_t; \
using Base::WARP_M; \ using Base::WARP_M; \
using Base::WARP_N; \ using Base::WARP_N; \
using Base::WARP_K; \ using Base::WARP_K; \
using Base::WARP_M_TILES; \ using Base::WARP_M_TILES; \
using Base::WARP_N_TILES; \ using Base::WARP_N_TILES; \
using Base::WARP_K_TILES; \ using Base::WARP_K_TILES; \
using Base::WSCALES_PACK_SIZE; \ using Base::WSCALES_PACK_SIZE; \
using Base::WSCALES_NUM_PACKS; \ using Base::WSCALES_NUM_PACKS; \
using Base::WSCALES_VALID_LANES; \ using Base::WSCALES_VALID_LANES; \
using Base::ASCALES_PACK_SIZE; \ using Base::ASCALES_PACK_SIZE; \
using Base::ASCALES_NUM_PACKS; \ using Base::ASCALES_NUM_PACKS; \
using Base::ASCALES_VALID_LANES; \ using Base::ASCALES_VALID_LANES; \
using typename Base::packed_act_t; \ using typename Base::packed_act_t; \
using typename Base::packed_wgt_t; \ using typename Base::packed_wgt_t; \
using typename Base::packed_psum_t; \ using typename Base::packed_psum_t; \
using typename Base::packed_fpsum_t; \ using typename Base::packed_fpsum_t; \
using typename Base::packed_gated_fpsum_t; \ using typename Base::packed_gated_fpsum_t; \
using typename Base::packed_f32psum_t; \ using typename Base::packed_f32psum_t; \
using typename Base::packed_wscale_t; \ using typename Base::packed_wscale_t; \
using typename Base::packed_ascale_t; \ using typename Base::packed_ascale_t; \
using typename Base::act_warp; \ using typename Base::act_warp; \
using typename Base::wgt_warp; \ using typename Base::wgt_warp; \
using typename Base::ascale_warp; \ using typename Base::ascale_warp; \
using typename Base::wscale_warp; \ using typename Base::wscale_warp; \
using typename Base::fpsum_warp; \ using typename Base::fpsum_warp; \
using typename Base::f32psum_warp; \ using typename Base::f32psum_warp; \
using typename Base::gated_fpsum_warp; \ using typename Base::gated_fpsum_warp; \
using typename Base::BlockInfo; \ using typename Base::BlockInfo; \
using typename Base::unpack_fpsum; \ using typename Base::unpack_fpsum; \
using typename Base::EpilogueDefault; \ using typename Base::EpilogueDefault; \
using typename Base::EpilogueNop; \ using typename Base::EpilogueNop; \
template<bool USE_BIAS, bool USE_SCALE> \ template<bool USE_BIAS, bool USE_SCALE> \
using EpilogueBias = typename Base::EpilogueBias<USE_BIAS, USE_SCALE>; \ using EpilogueBias = typename Base::EpilogueBias<USE_BIAS, USE_SCALE>; \
using Base::mma_f16xf16_f32; \ using Base::mma_f16xf16_f32; \
using Base::packed_fp32_to_fp16; \ using Base::packed_fp32_to_fp16; \
using Base::packed_fp16_to_fp32; \ using Base::packed_fp16_to_fp32; \
using Base::load_act; \ using Base::load_act; \
using Base::load_wgt; \ using Base::load_wgt; \
using Base::load_ascale; \ using Base::load_ascale; \
using Base::load_wscale; \ using Base::load_wscale; \
using Base::broadcast_wscale; \ using Base::broadcast_wscale; \
using Base::broadcast_ascale; \ using Base::broadcast_ascale; \
using Base::apply_scales; \ using Base::apply_scales; \
using Base::pack_ascales; \ using Base::pack_ascales; \
using Base::pack_wscales; \ using Base::pack_wscales; \
using Base::apply_act; using Base::apply_act;
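// A derived kernel is expected to inherit from GEMMBase<Config> and re-import
// the dependent names with the macro above. Illustrative sketch (the kernel
// name and members are hypothetical):
//
//   template<typename Config>
//   struct MyGemmKernel : public GEMMBase<Config> {
//       IMPORT_GEMM_BASE(Config);
//
//       static constexpr int MIN_ARCH = 800;  // optional; picked up by min_arch() below
//
//       __device__ __forceinline__ void operator()(/* packed operands */) {
//           // use WARP_M, load_act(), apply_scales(), EpilogueDefault, ...
//       }
//   };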
template<typename kernel> template<typename kernel>
constexpr int min_arch() { constexpr int min_arch() {
if constexpr (requires {kernel::MIN_ARCH;}) { if constexpr (requires { kernel::MIN_ARCH; }) {
return kernel::MIN_ARCH; return kernel::MIN_ARCH;
} else { } else {
return 0; return 0;
...@@ -877,16 +881,15 @@ constexpr int min_arch() { ...@@ -877,16 +881,15 @@ constexpr int min_arch() {
} }
template<typename kernel> template<typename kernel>
constexpr int max_arch() { constexpr int max_arch() {
if constexpr (requires {kernel::MAX_ARCH;}) { if constexpr (requires { kernel::MAX_ARCH; }) {
return kernel::MAX_ARCH; return kernel::MAX_ARCH;
} else { } else {
return INT_MAX; return INT_MAX;
} }
} }
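// min_arch() and max_arch() probe the kernel type with a C++20
// requires-expression: kernels may declare optional MIN_ARCH / MAX_ARCH
// members, and kernels that do not default to the widest range [0, INT_MAX].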
template<typename kernel, typename ...T> template<typename kernel, typename... T>
__global__ __global__ static void invoke_kernel(T... args) {
static void invoke_kernel(T ...args) {
#ifdef __CUDA_ARCH__ #ifdef __CUDA_ARCH__
if constexpr (__CUDA_ARCH__ >= min_arch<kernel>() && __CUDA_ARCH__ <= max_arch<kernel>()) { if constexpr (__CUDA_ARCH__ >= min_arch<kernel>() && __CUDA_ARCH__ <= max_arch<kernel>()) {
kernel()(args...); kernel()(args...);
...@@ -900,8 +903,7 @@ static void invoke_kernel(T ...args) { ...@@ -900,8 +903,7 @@ static void invoke_kernel(T ...args) {
} }
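// Illustrative launch of the wrapper above (grid/block/stream values are
// placeholders, not taken from this file):
//
//   invoke_kernel<MyGemmKernel<Config>> <<<grid, block, 0, stream>>>(act, wgt, out /* ... */);
//
// On architectures outside [min_arch, max_arch] the call to the functor is
// compiled out by the __CUDA_ARCH__ check, so one binary can carry kernels
// with different architecture requirements.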
template<typename T> template<typename T>
__global__ __global__ static void test_sizeof_device() {
static void test_sizeof_device() {
printf("sizeof on device = %d\n", (int)sizeof(T)); printf("sizeof on device = %d\n", (int)sizeof(T));
} }
...@@ -918,4 +920,4 @@ static void test_sizeof() { ...@@ -918,4 +920,4 @@ static void test_sizeof() {
checkCUDA(cudaDeviceSynchronize()); checkCUDA(cudaDeviceSynchronize());
} }
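// test_sizeof() / test_sizeof_device() are debugging helpers: printing
// sizeof(T) as seen by the device (and, presumably, comparing it against the
// host-side value) makes host/device layout mismatches, the issue the
// EpilogueNop workaround above guards against, easy to spot.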
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file