Commit e9ad0535 authored by muyangli

[major] support SANA

parent 9eb2cee0
#include "SanaModel.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "kernels/misc_kernels.h"
#include <nvtx3/nvToolsExt.h>
using spdlog::fmt_lib::format;
using namespace nunchaku;
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_pad(ceilDiv(dim, 128) * 128),
qkv_proj(dim, dim_pad * 3, bias, dtype, device),
out_proj(dim_pad, dim, bias, dtype, device),
pag_to_v(std::nullopt)
{
registerChildren
(qkv_proj, "qkv_proj")
(out_proj, "out_proj")
;
if (pag) {
pag_to_v.emplace(dim, dim_pad, bias, dtype, device);
registerChildren(pag_to_v.value(), "pag_to_v");
}
}
Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
constexpr int HEAD_DIM = 32;
assert(x.ndims() == 3);
const int batch_size = x.shape[0];
const int num_tokens = x.shape[1];
const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
assert(x.shape[2] == dim);
const int num_heads = dim_pad / HEAD_DIM;
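// The quantized kernels work on fixed tiles (BLOCK_M = 256, BLOCK_N = 128 in the W4A4 GEMM
// config later in this commit), so the channel dim is padded to a multiple of 128 in the
// constructor and the token count to a multiple of 256 here, e.g. 1025 tokens ->
// ceilDiv(1025, 256) * 256 = 1280. Padded rows are zero-filled and stripped again below.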
if (num_tokens_pad != num_tokens) {
spdlog::debug("SanaLinearAttention: pad num_tokens from {} to {}", num_tokens, num_tokens_pad);
Tensor x_pad = Tensor::allocate({batch_size, num_tokens_pad, dim}, x.dtype(), x.device());
x_pad.zero_();
for (int i = 0; i < batch_size; i++) {
x_pad.slice(0, i, i + 1).slice(1, 0, num_tokens).copy_(x.slice(0, i, i + 1));
}
x = x_pad;
}
auto qact = qkv_proj.quantize(x, false);
Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());
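// Fused W4A4 QKV projection: the kernel writes the projected q ([B, T_pad, dim_pad]) and a
// per-head FP32 accumulator vk ([B, num_heads, HEAD_DIM + 1, HEAD_DIM]); linearattn_vk_mul_q
// below combines the two into the linear-attention output (stored back into q).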
kernels::gemm_w4a4(
qact.act,
qkv_proj.qweight,
{},
{},
qact.ascales,
qkv_proj.wscales,
{}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
vk, q,
qact.is_unsigned, qkv_proj.lora_scales, false);
debug("vk", vk);
debug("q", q);
kernels::linearattn_vk_mul_q(q, vk);
debug("raw_attn_output", q);
if (num_tokens_pad != num_tokens) {
Tensor q_unpad = Tensor::allocate({batch_size, num_tokens, dim_pad}, q.dtype(), q.device());
for (int i = 0; i < batch_size; i++) {
q_unpad.slice(0, i, i + 1).copy_(q.slice(0, i, i + 1).slice(1, 0, num_tokens));
}
q = q_unpad;
}
// kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
// return out_proj.forward(q);
if (!out.valid()) {
out = Tensor::allocate({batch_size, num_tokens, dim}, q.dtype(), q.device());
}
out_proj.forward(q, out);
return out;
}
Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
const int batch_size = x.shape[0];
const int num_tokens = x.shape[1];
Tensor out = Tensor::allocate({batch_size, num_tokens, dim}, x.dtype(), x.device());
Tensor x_org, x_ptb;
Tensor out_org, out_ptb;
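// PAG batch layout: with cfg the leading two thirds of the batch take the regular
// linear-attention path and the trailing third is the perturbed branch; without cfg the
// split is half/half. The perturbed branch bypasses attention and only runs pag_to_v
// followed by out_proj.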
if (cfg) {
assert(batch_size % 3 == 0);
x_org = x.slice(0, 0, batch_size * 2 / 3);
x_ptb = x.slice(0, batch_size * 2 / 3, batch_size);
out_org = out.slice(0, 0, batch_size * 2 / 3);
out_ptb = out.slice(0, batch_size * 2 / 3, batch_size);
} else {
assert(batch_size % 2 == 0);
x_org = x.slice(0, 0, batch_size / 2);
x_ptb = x.slice(0, batch_size / 2, batch_size);
out_org = out.slice(0, 0, batch_size / 2);
out_ptb = out.slice(0, batch_size / 2, batch_size);
}
this->forward(x_org, out_org);
Tensor v_ptb = this->pag_to_v.value().forward(x_ptb);
this->out_proj.forward(v_ptb, out_ptb);
return out;
}
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device) :
num_heads(num_heads), head_dim(head_dim),
q_linear(num_heads * head_dim, num_heads * head_dim, true, dtype, device),
kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
out_proj(num_heads * head_dim, num_heads * head_dim, true, dtype, device)
{
registerChildren
(q_linear, "q_linear")
(kv_linear, "kv_linear")
(out_proj, "out_proj")
;
}
Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) {
assert(x.ndims() == 3);
assert(cond.ndims() == 2);
assert(cu_seqlens_img.ndims() == 1);
assert(cu_seqlens_txt.ndims() == 1);
const int batch_size = x.shape[0];
const int num_tokens_img = x.shape[1];
const int num_tokens_txt = cond.shape[0];
assert(cu_seqlens_img.shape[0] == batch_size + 1);
assert(cu_seqlens_txt.shape[0] == batch_size + 1);
Tensor q = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim});
Tensor kv = kv_linear.forward(cond).view({num_tokens_txt, num_heads * 2, head_dim});
Tensor k = kv.slice(1, 0, num_heads);
Tensor v = kv.slice(1, num_heads, num_heads * 2);
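// Variable-length cross attention: image tokens act as queries and text tokens as keys/values;
// cu_seqlens_img / cu_seqlens_txt are [batch_size + 1] cumulative token counts following the
// flash-attention varlen convention.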
Tensor attn_output = mha_varlen_fwd(
q, k, v,
cu_seqlens_img, cu_seqlens_txt,
num_tokens_img, num_tokens_txt,
0.0f,
pow(q.shape[-1], (-0.5)),
false, false,
-1, -1,
false
).front().view({batch_size, num_tokens_img, num_heads * head_dim});
// Tensor attn_output = mha_fwd(q, k, v,
// 0.0f,
// pow(q.shape[-1], (-0.5)),
// false, -1, -1, false
// ).front().view({B, N, num_heads * head_dim});
return out_proj.forward(attn_output);
}
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device) :
in_features(in_features), hidden_features(hidden_features),
inverted_conv(in_features, hidden_features * 2, true, dtype, device),
depth_conv(hidden_features * 2, true, dtype, device),
point_conv(hidden_features, in_features, false, dtype, device)
{
registerChildren
(inverted_conv, "inverted_conv")
(depth_conv, "depth_conv")
(point_conv, "point_conv")
;
}
Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
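// If the spatial size is not provided, assume a square token grid (H = W = sqrt(num_tokens)).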
if (H <= 0 || W <= 0) {
H = W = sqrt(x.shape[1]);
}
x = inverted_conv.forward_silu(x);
x = x.view({x.shape[0], H, W, x.shape[-1]});
debug("inverted_conv_output", x);
x = depth_conv.forward(x);
debug("depth_conv_output", x);
x = x.view({x.shape[0], H * W, x.shape[-1]});
auto qact = point_conv.quantize(x, true);
return point_conv.forward_quant(qact);
}
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device) :
hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
attn(hidden_size, false, pag, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, dtype, device),
ff(hidden_size, intermediate_size, dtype, device),
norm1(hidden_size, 1e-6, false, dtype, device),
norm2(hidden_size, 1e-6, false, dtype, device)
{
this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);
registerChildren
(attn, "attn")
(cross_attn, "cross_attn")
(ff, "ff")
;
registerParams
(this->scale_shift_table, "scale_shift_table")
;
}
Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
nvtxRangePushA("SanaLinearTransformerBlock");
nvtxRangePushA("chunk");
// Tensor ones = Tensor::ones({hidden_size}, Tensor::FP16, x.device());
const int batch_size = timestep.shape[0];
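// timestep holds the per-sample modulation embedding; adding the learned scale_shift_table and
// chunking into six parts yields shift/scale/gate for the linear-attention branch and for the
// feed-forward branch.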
timestep = timestep.copy(timestep.device());
timestep = timestep.view({batch_size, 6, hidden_size});
kernels::mul_add_batch(timestep, {}, false, 0, this->scale_shift_table, false);
debug("shifted_timestep", timestep);
std::array<Tensor, 6> chunked;
for (int i = 0; i < 6; i++) {
chunked[i] = timestep.slice(1, i, i + 1);
}
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = chunked;
// auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(timestep);
nvtxRangePop();
{
nvtxRangePushA("LinearAttention");
Tensor residual = hidden_states;
Tensor norm_hidden_states = norm1.forward(hidden_states);
kernels::mul_add_batch(norm_hidden_states, scale_msa, true, 1, shift_msa, true);
debug("norm_hidden_states_la", norm_hidden_states);
Tensor attn_output = pag ? attn.forward_pag(norm_hidden_states, cfg) : attn.forward(norm_hidden_states);
debug("attn_output_la", attn_output);
kernels::mul_add_batch(attn_output, gate_msa, true, 0, residual, true);
hidden_states = attn_output;
nvtxRangePop();
}
{
nvtxRangePushA("CrossAttention");
debug("norm_hidden_states_cross", hidden_states);
Tensor attn_output = cross_attn.forward(hidden_states, encoder_hidden_states, cu_seqlens_img, cu_seqlens_txt);
debug("attn_output_cross", attn_output);
kernels::mul_add_batch(attn_output, {}, false, 0, hidden_states, true);
hidden_states = attn_output;
nvtxRangePop();
}
{
nvtxRangePushA("Feed-forward");
debug("hidden_states_ff", hidden_states);
Tensor norm_hidden_states = norm2.forward(hidden_states);
kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 1, shift_mlp, true);
debug("norm_hidden_states_ff", norm_hidden_states);
Tensor ff_output = ff.forward(norm_hidden_states, H, W);
debug("ff_output", ff_output);
kernels::mul_add_batch(ff_output, gate_mlp, true, 0, hidden_states, true);
hidden_states = ff_output;
nvtxRangePop();
}
nvtxRangePop();
debug("hidden_states_out", hidden_states);
return hidden_states;
}
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) :
config(config)
{
const int inner_dim = config.num_attention_heads * config.attention_head_dim;
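// FFN hidden size: expand_ratio * inner_dim, rounded up to a multiple of 64 (presumably to keep
// the quantized GEMM shapes aligned).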
for (int i = 0; i < config.num_layers; i++) {
transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>(
inner_dim,
ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64,
config.num_cross_attention_heads,
std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
dtype, device
));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
}
}
Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
for (int i = 0; i < config.num_layers; i++) {
auto &&block = transformer_blocks[i];
hidden_states = block->forward(
hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
cfg
);
}
return hidden_states;
}
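// Illustrative sketch (not part of this commit, not compiled): how the model above might be
// driven for a single forward pass. Shapes follow the asserts in the code; the tensors are
// assumed to be prepared by the host bindings.
#if 0
Tensor sana_forward_once(SanaModel &model,
                         Tensor hidden_states,         // [B, H * W, inner_dim]
                         Tensor encoder_hidden_states, // [total_txt_tokens, inner_dim]
                         Tensor timestep,              // [B, 6 * inner_dim]
                         Tensor cu_seqlens_img,        // [B + 1] cumulative image token counts
                         Tensor cu_seqlens_txt,        // [B + 1] cumulative text token counts
                         int H, int W) {
    return model.forward(hidden_states, encoder_hidden_states, timestep,
                         cu_seqlens_img, cu_seqlens_txt, H, W,
                         /*pag=*/false, /*cfg=*/false);
}
#endif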
#pragma once
#include "common.h"
#include "Tensor.h"
#include "Linear.h"
#include "layernorm.h"
class SanaLinearAttention : public Module {
public:
SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, Tensor out = {});
Tensor forward_pag(Tensor x, bool cfg);
public:
const int dim;
const int dim_pad;
private:
GEMM_W4A4 qkv_proj;
GEMM_W4A4 out_proj;
std::optional<GEMM_W4A4> pag_to_v;
};
class MultiHeadCrossAttention : public Module {
public:
MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);
public:
const int num_heads;
const int head_dim;
private:
GEMM_W4A4 q_linear;
GEMM_F16 kv_linear;
GEMM_W4A4 out_proj;
};
class SanaGLUMBConv : public Module {
public:
SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, int H, int W);
public:
const int in_features;
const int hidden_features;
private:
GEMM_W4A4 inverted_conv;
DWCONV depth_conv;
GEMM_W4A4 point_conv;
};
class SanaLinearTransformerBlock : public Module {
public:
SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
public:
const int hidden_size;
const int num_cross_attention_heads;
private:
Tensor scale_shift_table;
// Tensor ones;
SanaLinearAttention attn;
MultiHeadCrossAttention cross_attn;
SanaGLUMBConv ff;
LayerNorm norm1, norm2;
};
struct SanaConfig {
int num_layers;
int num_attention_heads;
int attention_head_dim;
int num_cross_attention_heads;
double expand_ratio;
std::vector<int> pag_layers;
};
class SanaModel : public Module {
public:
SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
public:
const SanaConfig config;
public:
std::vector<std::unique_ptr<SanaLinearTransformerBlock>> transformer_blocks;
};
\ No newline at end of file
@@ -32,6 +32,10 @@ public:
size_t getSize() { return size; }
Device getDevice() { return device; }
virtual bool isAsyncBuffer() {
return false;
}
protected:
template <typename Derived>
std::shared_ptr<Derived> shared_from_base() {
@@ -90,6 +94,9 @@ public:
}
checkCUDA(cudaFreeAsync(this->ptr, 0));
}
virtual bool isAsyncBuffer() override {
return true;
}
};
class BufferCUDASync : public Buffer {
@@ -499,16 +506,16 @@ private:
return cudaMemcpyDefault;
}
static bool isAsyncBuffer(Buffer *buffer) {
return dynamic_cast<BufferCUDA *>(buffer);
}
// static bool isAsyncBuffer(Buffer *buffer) {
// return dynamic_cast<BufferCUDA *>(buffer);
// }
static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
public:
// before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU completes
static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
if (!isAsyncBuffer(buffer.get())) {
if (!buffer->isAsyncBuffer()) {
lockedBuffers[stream].insert(buffer);
}
}
@@ -33,7 +33,7 @@ Tensor from_torch(at::Tensor input) {
result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
// Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
return result;
}
@@ -13,6 +13,10 @@ public:
this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
this->device.idx = this->tensor.get_device();
}
virtual bool isAsyncBuffer() override {
// TODO: figure out how torch manages memory
return true;
}
private:
at::Tensor tensor;
};
#pragma once
#include "common.h"
#include "Tensor.h"
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/bfloat16.h>
template<typename F>
inline void dispatchF16(Tensor::ScalarType type, F &&func) {
if (type == Tensor::FP16) {
func.template operator()<cutlass::half_t>();
} else if (type == Tensor::BF16) {
func.template operator()<cutlass::bfloat16_t>();
} else {
assert(false);
}
}
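// Usage (as in dwconv_f16 / gemm_f16 later in this commit): dispatch on the runtime dtype and
// instantiate the body for cutlass::half_t or cutlass::bfloat16_t:
//   dispatchF16(weight.dtype(), [&]<typename half_t>() { /* ... */ });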
\ No newline at end of file
#include "common.h"
#include "Tensor.h"
#include "dispatch_cutlass.h"
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
@@ -10,6 +12,7 @@
// depthwise_Conv2d operation cutlass_sm80_tensorop_f16_s16x8x16fprop_analytic_f16_256x128_64x3_nhwc_align8
#if 0
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
using FilterShape = cutlass::MatrixShape<3, 3>;
@@ -194,3 +197,140 @@ Tensor depthwise_conv2d_kernel(Tensor A, Tensor B) {
return D;
}
#endif
Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
assert(input.ndims() == 4);
const int N = input.size(0);
const int H = input.size(1);
const int W = input.size(2);
const int C_ = input.size(3);
assert(weight.ndims() == 4);
const int K = weight.size(0);
const int R = weight.size(1);
const int S = weight.size(2);
const int C__ = weight.size(3);
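// 3x3 depthwise convolution in NHWC layout with stride 1, padding 1, dilation 1 and
// groups == C (see the Conv2dProblemSize below), implemented with CUTLASS's direct
// depthwise-conv kernel for SM80.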
// weight = weight.copy(weight.device());
dispatchF16(weight.dtype(), [&]<typename half_t>() {
using ElementOutput = half_t;
using ElementAccumulator = half_t;
using ElementComputeEpilogue = half_t;
using ElementInputA = half_t;
using ElementInputB = half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
using FilterShape = cutlass::MatrixShape<3, 3>;
using ThreadblockShape = cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, 64, FilterShape::kCount>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, FilterShape::kCount>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm80,
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput, ElementComputeEpilogue>,
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
1,
ThreadBlockOutputShape::kN,
ThreadBlockOutputShape::kH,
ThreadBlockOutputShape::kW>,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kFixedStrideDilation,
cutlass::conv::StrideSupport::kFixed,
cutlass::MatrixShape<1, 1>,
cutlass::MatrixShape<1, 1>>::Kernel;
using DeviceKernel = typename cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C__),
cutlass::Tensor4DCoord(1, 1, 1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::conv::Mode::kCrossCorrelation,
1,
C_ // groups
);
const int P = problem_size.P;
const int Q = problem_size.Q;
if (!out.valid()) {
out = Tensor::allocate({N, P, Q, K}, input.dtype(), input.device());
}
assert(out.ndims() == 4);
assert(out.size(0) == N);
assert(out.size(1) == P);
assert(out.size(2) == Q);
assert(out.size(3) == K);
Tensor tmp_weight = Tensor::empty_like(weight);
cutlass::TensorRef<ElementInputA, LayoutInputA> a_ref(input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(2), input.stride(1), input.stride(0)));
cutlass::TensorRef<ElementInputB, LayoutInputB> b_ref(weight.data_ptr<ElementInputB>(), LayoutInputB(weight.stride(2), weight.stride(1), weight.stride(0)));
cutlass::TensorRef<ElementOutput, LayoutOutput> c_ref(bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0, 0, 0));
cutlass::TensorRef<ElementOutput, LayoutOutput> d_ref(out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(2), out.stride(1), out.stride(0)));
cutlass::TensorRef<ElementOutput, LayoutOutput> tmpw_ref(tmp_weight.data_ptr<ElementOutput>(), LayoutOutput(tmp_weight.stride(2), tmp_weight.stride(1), tmp_weight.stride(0)));
typename DeviceKernel::Arguments arguments{
problem_size,
a_ref,
b_ref,
c_ref,
d_ref,
{ElementOutput(1.0f), ElementOutput(bias.valid() ? 1.0f : 0.0f)},
tmpw_ref,
};
DeviceKernel implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
BufferCUDA workspace(workspace_size);
auto stream = getCurrentCUDAStream();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}
status = implicit_gemm_op.initialize(arguments, workspace.getPtr(), stream);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = implicit_gemm_op(stream);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
});
return out;
}
\ No newline at end of file
@@ -3,4 +3,6 @@
#include "common.h"
#include "Tensor.h"
Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
\ No newline at end of file
// Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias);
\ No newline at end of file
@@ -47,7 +47,7 @@ Tensor gemm_batched_fp16(
auto sizeO = cutlass::MatrixCoord(M, N);
if (!out.valid()) {
auto outShape = a.shape;
auto outShape = TensorShape(a.shape.dataExtent);
outShape[-1] = N;
out = Tensor::empty(outShape, Tensor::FP32, a.device());
}
#include "gemm_f16.h"
#include "dispatch_cutlass.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
@@ -13,8 +15,8 @@ using spdlog::fmt_lib::format;
Tensor gemm_f16(Tensor input, // FP16
Tensor weight, // FP16
Tensor out, // FP16
float alpha,
float beta
Tensor bias,
float alpha
) {
auto N = weight.size(0);
auto K = input.size(-1);
@@ -23,102 +25,111 @@ Tensor gemm_f16(Tensor input, // FP16
spdlog::debug("gemm_f16: M={} K={} N={}", M, K, N);
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using ElementComputeEpilogue = cutlass::bfloat16_t;
using ElementInputA = cutlass::bfloat16_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::bfloat16_t; // <- data type of elements in input matrix B
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// #if CUDA_ARCH >= 800
using Gemm = cutlass::gemm::device::Gemm<
ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
// auto out = bias.to(device).view({1, -1}).repeat({M, 1});
if (!out.valid()) {
auto out_shape = input.shape;
out_shape[-1] = N;
out = Tensor::empty(out_shape, input.scalar_type(), input.device());
}
// FIXME: check contiguous of input if dims >= 3
assert(input.stride(-1) == 1);
// assert(input.is_contiguous());
assert(weight.is_contiguous());
assert(out.dtype() == input.scalar_type());
assert(out.shape[-1] == N);
assert(out.numel() / out.shape[-1] == M);
assert(out.stride(-1) == 1);
// FIXME: check contiguous of output if dims >= 3
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2)));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
out_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{ElementOutput(alpha), ElementOutput(beta)},
1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
// cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
BufferCUDA workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.getPtr());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
dispatchF16(weight.dtype(), [&]<typename half_t>() {
using ElementOutput = half_t;
using ElementAccumulator = float;
using ElementComputeEpilogue = half_t;
using ElementInputA = half_t; // <- data type of elements in input matrix A
using ElementInputB = half_t; // <- data type of elements in input matrix B
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// #if CUDA_ARCH >= 800
using Gemm = cutlass::gemm::device::Gemm<
ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
// auto out = bias.to(device).view({1, -1}).repeat({M, 1});
if (!out.valid()) {
auto out_shape = TensorShape(input.shape.dataExtent);
out_shape[-1] = N;
out = Tensor::empty(out_shape, input.scalar_type(), input.device());
}
// FIXME: check contiguous of input if dims >= 3
assert(input.stride(-1) == 1);
// assert(input.is_contiguous());
assert(weight.is_contiguous());
assert(out.dtype() == input.scalar_type());
assert(out.shape[-1] == N);
assert(out.numel() / out.shape[-1] == M);
assert(out.stride(-1) == 1);
// FIXME: check contiguous of output if dims >= 3
assert(!bias.valid() || (bias.ndims() == 1 && bias.shape[0] == N));
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2)));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
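// When a bias is given it is used as matrix C with a stride-0 row layout (broadcast over
// rows) and beta = 1; otherwise C aliases the output and beta = 0.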
cutlass::TensorRef<ElementOutput, LayoutOutput> bias_ref(
bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
bias_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{ElementOutput(alpha), ElementOutput(bias.valid() ? 1.0f : 0.0f)},
1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
// cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
BufferCUDA workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.getPtr());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
});
return out;
}
@@ -7,6 +7,6 @@ Tensor gemm_f16(
Tensor input, // FP16
Tensor weight, // FP16
Tensor out, // FP16
float alpha,
float beta
Tensor bias,
float alpha
);
\ No newline at end of file
@@ -82,7 +82,7 @@ Tensor gemm_w8a8_fp16(Tensor input, // INT8
// auto out = bias.to(device).view({1, -1}).repeat({M, 1});
if (!out.valid()) {
auto out_shape = input.shape;
auto out_shape = TensorShape(input.shape.dataExtent);
out_shape[-1] = N;
out = Tensor::empty(out_shape, Tensor::FP16, input.device());
}
@@ -2,6 +2,8 @@
#include "misc_kernels.h"
#include "dispatch_utils.h"
namespace nunchaku::kernels {
Tensor add(Tensor a, Tensor b) {
assert(a.shape.dataExtent == b.shape.dataExtent);
assert(a.dtype() == b.dtype());
@@ -34,11 +36,11 @@ void mul_add(Tensor x, Tensor scale, Tensor bias) {
constexpr int unroll = 8;
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
assert((uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(x.numel() % unroll == 0);
assert(scale.numel() % unroll == 0);
assert(!scale.valid() || scale.numel() % unroll == 0);
assert(bias.numel() % unroll == 0);
int threadsPerBlock = 1024;
@@ -47,8 +49,60 @@ void mul_add(Tensor x, Tensor scale, Tensor bias) {
auto stream = getCurrentCUDAStream();
dispatch(x.scalar_type(), [&]<typename scalar_t>() {
mul_add_kernel<scalar_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), x.numel(), scale.numel(), bias.numel());
if (scale.valid()) {
mul_add_kernel<scalar_t, unroll, false><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), 0, x.numel(), scale.numel(), bias.numel(), 0, 0, 0);
} else {
mul_add_kernel<scalar_t, unroll, true><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(), 0, x.numel(), 1, bias.numel(), 0, 0, 0);
}
});
}
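// Batched variant of mul_add: computes x = x * (scale + scale_shift) + bias when scale is
// valid, and x = x + bias otherwise. scale and bias may either be shared across the batch or
// carry a leading batch dimension (batch_scale / batch_bias).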
void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias) {
const int batch_size = x.shape[0];
assert(!batch_scale || scale.shape[0] == batch_size);
assert(!batch_bias || bias.shape[0] == batch_size);
const int numel = x.numel() / batch_size;
const int numel_scale = scale.valid() ? (scale.numel() / (batch_scale ? batch_size : 1)) : 1;
const int numel_bias = bias.numel() / (batch_bias ? batch_size : 1);
assert(numel % numel_scale == 0);
assert(numel % numel_bias == 0);
assert(!scale.valid() || x.dtype() == scale.dtype());
assert(x.dtype() == bias.dtype());
constexpr int unroll = 8;
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(numel % unroll == 0);
assert(!scale.valid() || numel_scale % unroll == 0);
assert(numel_bias % unroll == 0);
int threadsPerBlock = 1024;
dim3 grid(ceilDiv(numel, threadsPerBlock * unroll), batch_size);
auto stream = getCurrentCUDAStream();
dispatch(x.scalar_type(), [&]<typename scalar_t>() {
if (scale.valid()) {
mul_add_kernel<scalar_t, unroll, false><<<grid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
(scalar_t)scale_shift,
numel, numel_scale, numel_bias,
x.stride(0), batch_scale ? scale.stride(0) : 0, batch_bias ? bias.stride(0) : 0);
} else {
mul_add_kernel<scalar_t, unroll, true><<<grid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(),
(scalar_t)scale_shift,
numel, 1, numel_bias,
x.stride(0), 0, batch_bias ? bias.stride(0) : 0);
}
});
}
@@ -219,7 +273,7 @@ Tensor topk(Tensor x, int k) {
assert(k <= N);
assert(k <= MAXK);
auto outShape = x.shape;
auto outShape = TensorShape(x.shape.dataExtent);
outShape[-1] = k;
outShape.dataStride.clear();
@@ -252,4 +306,6 @@ template std::array<Tensor, 2> split_mod<2>(Tensor input);
template std::array<Tensor, 3> split_mod<3>(Tensor input);
template std::array<Tensor, 4> split_mod<4>(Tensor input);
template std::array<Tensor, 5> split_mod<5>(Tensor input);
template std::array<Tensor, 6> split_mod<6>(Tensor input);
\ No newline at end of file
template std::array<Tensor, 6> split_mod<6>(Tensor input);
}; // namespace nunchaku::kernels
\ No newline at end of file
@@ -3,8 +3,12 @@
#include "common.h"
#include "Tensor.h"
namespace nunchaku::kernels {
Tensor add(Tensor a, Tensor b);
void mul_add(Tensor x, Tensor scale, Tensor bias);
void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias);
Tensor embedding(Tensor input_id, Tensor lookup);
Tensor argmax_sample(Tensor logits);
void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v);
@@ -16,4 +20,6 @@ void cast(Tensor input, Tensor output);
Tensor topk(Tensor x, int k);
template<size_t N>
std::array<Tensor, N> split_mod(Tensor input);
\ No newline at end of file
std::array<Tensor, N> split_mod(Tensor input);
}; // namespace nunchaku::kernels
\ No newline at end of file
@@ -7,6 +7,8 @@
#include "utils.cuh"
#include "activation_kernels_impl.cuh"
namespace nunchaku::kernels {
template<typename T>
__global__ void add_kernel(T *a, T *b, T *c, size_t length) {
@@ -21,9 +23,9 @@ struct alignas(sizeof(T) * unroll) Tvec {
T data[unroll];
};
template<typename T, int unroll>
__global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_scale, int mod_bias) {
template<typename T, int unroll, bool no_scale>
__global__ void mul_add_kernel(T *x, T *scale, T *bias, T scale_shift, size_t length, int mod_scale, int mod_bias, int64_t batch_stride_x, int64_t batch_stride_scale, int64_t batch_stride_bias) {
const int batch_id = blockIdx.y;
int thread = threadIdx.x + blockIdx.x * blockDim.x;
int i = thread * unroll;
int i_scale = i % mod_scale;
@@ -33,15 +35,20 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_s
return;
}
using Tvec = ::Tvec<T, unroll>;
using Tvec = nunchaku::kernels::Tvec<T, unroll>;
Tvec rx = *reinterpret_cast<Tvec *>(&x[i]);
Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale]);
Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias]);
Tvec rx = *reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]);
Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale + batch_stride_scale * batch_id]);
Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias + batch_stride_bias * batch_id]);
#pragma unroll
for (int k = 0; k < unroll; k++) {
T tmp = rx.data[k] * rscale.data[k] + rbias.data[k];
T tmp;
if constexpr (no_scale) {
tmp = rx.data[k] + rbias.data[k];
} else {
tmp = rx.data[k] * (rscale.data[k] + scale_shift) + rbias.data[k];
}
if constexpr (std::is_same_v<T, half>) {
tmp = __hmin(tmp, (half)65504);
tmp = __hmax(tmp, (half)-65504);
@@ -49,7 +56,7 @@
rx.data[k] = tmp;
}
*reinterpret_cast<Tvec *>(&x[i]) = rx;
*reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]) = rx;
// #pragma unroll
// for (int k = 0; k < unroll; k++) {
@@ -127,8 +134,8 @@ __global__ void quant_kernel_static(const T * input, int8_t * output, T scale, s
return;
}
using Tvec = ::Tvec<T, unroll>;
using I8vec = ::Tvec<int8_t, unroll>;
using Tvec = nunchaku::kernels::Tvec<T, unroll>;
using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>;
Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
I8vec routput;
@@ -149,8 +156,8 @@ __global__ void quant_kernel_static_fuse_gelu(const T * input, int8_t * output,
return;
}
using Tvec = ::Tvec<T, unroll>;
using I8vec = ::Tvec<int8_t, unroll>;
using Tvec = nunchaku::kernels::Tvec<T, unroll>;
using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>;
Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
I8vec routput;
@@ -168,8 +175,8 @@ template<typename Tin, typename Tout, int unroll>
__global__ void cast_kernel(const Tin *input, Tout *output, size_t length) {
const int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
using Tvec_in = ::Tvec<Tin, unroll>;
using Tvec_out = ::Tvec<Tout, unroll>;
using Tvec_in = nunchaku::kernels::Tvec<Tin, unroll>;
using Tvec_out = nunchaku::kernels::Tvec<Tout, unroll>;
Tvec_in rinput = *reinterpret_cast<const Tvec_in *>(&input[i]);
Tvec_out routput;
@@ -250,4 +257,6 @@ void topk_kernel(const T *input, int *output, int N, int strideInput, int numRow
for (int i = 0; i < K; i++) {
output[row * K + i] = idx[K - i - 1];
}
}
\ No newline at end of file
}
}; // namespace nunchaku::kernels
\ No newline at end of file
#pragma once
#include "common.h"
#include "../utils.cuh"
#include "../dispatch_utils.h"
#include "gemm_utils.cuh"
#pragma nv_diag_suppress 177
#ifdef _MSC_VER
#define ALWAYSINLINE [[msvc::forceinline]]
#else
#define ALWAYSINLINE __attribute__((always_inline))
#endif
// #define ENABLE_NAN_CHECK 1
#if ENABLE_NAN_CHECK
#define STRINGIZE(x) STRINGIZE2(x)
#define STRINGIZE2(x) #x
#define CHECK_NAN(data, name) checkNan(data, name " at " STRINGIZE(__LINE__))
#else
#define CHECK_NAN(data, name)
#endif
namespace nunchaku::kernels {
template<bool bf16>
class GEMMConfig_W4A4 {
public:
// BE CAREFUL: weights need to be repacked when the tiling size changes
static constexpr int BLOCK_M = 256;
static constexpr int BLOCK_N = 128;
static constexpr int WARP_SIZE = 32;
static constexpr int NUM_WARPS = 8;
static constexpr int INSN_M = 16;
static constexpr int INSN_N = 16;
static constexpr int INSN_K = 64;
using half_t = typename std::conditional_t<bf16, __nv_bfloat16, half>;
using half2_t = typename std::conditional_t<bf16, __nv_bfloat162, half2>;
};
using GEMMConfig_W4A4_FP16 = GEMMConfig_W4A4<false>;
using GEMMConfig_W4A4_BF16 = GEMMConfig_W4A4<true>;
class GEMMConfig_W8A8 {
public:
static constexpr int BLOCK_M = 256;
static constexpr int BLOCK_N = 128;
static constexpr int WARP_SIZE = 32;
static constexpr int NUM_WARPS = 8;
static constexpr int INSN_M = 16;
static constexpr int INSN_N = 16;
static constexpr int INSN_K = 32;
#if 0
using half_t = half;
using half2_t = half2;
#else
using half_t = __nv_bfloat16;
using half2_t = __nv_bfloat162;
#endif
};
template<class Config>
class GEMMBase : public Config {
public:
using Config::BLOCK_M;
using Config::BLOCK_N;
using Config::WARP_SIZE;
using Config::NUM_WARPS;
using Config::INSN_M;
using Config::INSN_N;
using Config::INSN_K;
using typename Config::half_t;
using typename Config::half2_t;
static constexpr int WARP_M = BLOCK_M / NUM_WARPS;
static constexpr int WARP_N = BLOCK_N;
static constexpr int WARP_K = INSN_K;
static constexpr int WARP_M_TILES = WARP_M / INSN_M;
static constexpr int WARP_N_TILES = WARP_N / INSN_N;
static constexpr int WARP_K_TILES = WARP_K / INSN_K;
/**
* refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
*
* wscales store order: (pack = 4)
* 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
* ... ...
* 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* wscales store order: (pack = 8)
* 0 1 8 9 16 17 24 25 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 18 19 26 27 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 20 21 28 29 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 22 23 30 31 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* 224 225 232 233 240 241 248 249 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 230 231 238 239 246 247 254 255 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* {k}-th wscale used by lane {i} => {k // (WSCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {4*(k // WSCALES_PACK_SIZE) + i % 4}, element {k % WSCALES_PACK_SIZE}
*
* The max pack size is 8 since the max load size is 16 bytes per lane;
* the min pack size is 2 since the shuffle granularity is 32 bits (2 x half).
* */
static constexpr int WSCALES_PACK_SIZE = clamp(WARP_N / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
static constexpr int WSCALES_NUM_PACKS = ceilDiv(WARP_N, (WSCALES_PACK_SIZE * WARP_SIZE));
static constexpr int WSCALES_VALID_LANES = std::min(WARP_SIZE, WARP_N / WSCALES_PACK_SIZE);
/**
* ascales store order: (pack = 2)
* 0 8 <-- load by lane 0, broadcast to lane {0, 1, 2, 3} (4x)
* 1 9 <-- load by lane 1, broadcast to lane {4, 5, 6, 7} (4x)
* 2 10
* ...
* 6 14
* 7 15 <-- load by lane 7, broadcast to lane {28, 29, 30, 31} (4x)
* ... ...
* 48 56 <-- load by lane 24, broadcast to lane {0, 1, 2, 3} (4x)
* 49 57
* ...
* 54 62
* 55 63 <-- load by lane 31, broadcast to lane {28, 29, 30, 31} (4x)
*
* {k}-th ascale used by lane {i} => {k // (ASCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {8*(k // ASCALES_PACK_SIZE) + i // 4}, element {k % ASCALES_PACK_SIZE}
*/
static constexpr int ASCALES_PACK_SIZE = clamp(WARP_M / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
static constexpr int ASCALES_NUM_PACKS = ceilDiv(WARP_M, (ASCALES_PACK_SIZE * WARP_SIZE));
static constexpr int ASCALES_VALID_LANES = std::min(WARP_SIZE, WARP_M / ASCALES_PACK_SIZE);
using packed_act_t = uint4;
using packed_wgt_t = uint4;
struct alignas(32) packed_psum_t {
int data[8];
};
struct alignas(16) packed_fpsum_t {
half2_t data[4];
};
struct alignas(8) packed_gated_fpsum_t {
half_t data[4];
};
// 16 * 16 matrix
struct alignas(32) packed_f32psum_t {
float data[8];
static constexpr packed_f32psum_t zeros() {
packed_f32psum_t result;
for (int i = 0; i < 8; i++) {
result.data[i] = 0;
}
return result;
};
};
struct packed_wscale_t {
half2_t data[WSCALES_PACK_SIZE / 2];
};
struct packed_ascale_t {
half2_t data[ASCALES_PACK_SIZE / 2];
};
using act_warp = std::array<packed_act_t, WARP_M_TILES>;
using wgt_warp = std::array<packed_wgt_t, WARP_N_TILES>;
using ascale_warp = std::array<packed_ascale_t, ASCALES_NUM_PACKS>;
using wscale_warp = std::array<packed_wscale_t, WSCALES_NUM_PACKS>;
using fpsum_warp = std::array<packed_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
using f32psum_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>;
using gated_fpsum_warp = std::array<packed_gated_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
struct BlockInfo {
int bm;
int bn;
int numBlocksM;
int numBlocksN;
};
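// Computes one 16x16 (INSN_M x INSN_N) half/bf16 tile accumulated in FP32, issued as two
// m16n8k16 tensor-core mma.sync instructions (one per 8-column half of the tile).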
__device__ __forceinline__
static packed_f32psum_t mma_f16xf16_f32(packed_fpsum_t a, packed_fpsum_t b, packed_f32psum_t psum) {
static_assert(std::is_same_v<half_t, half> || std::is_same_v<half_t, __nv_bfloat16>);
if constexpr (std::is_same_v<half_t, half>) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[1])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[3])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
);
}
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[1])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[3])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
);
}
return psum;
}
__device__ __forceinline__
static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
packed_fpsum_t results;
for (int i = 0; i < 4; i++) {
results.data[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1]));
}
return results;
}
__device__ __forceinline__
static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
packed_f32psum_t results;
for (int i = 0; i < 4; i++) {
float2 tmp = half22float2(input.data[i]);
results.data[i * 2] = tmp.x;
results.data[i * 2 + 1] = tmp.y;
}
return results;
}
__device__ __forceinline__
static fpsum_warp packed_fp32_to_fp16(f32psum_warp input) {
fpsum_warp results;
#pragma unroll
for (int i = 0; i < results.size(); i++) {
results[i] = packed_fp32_to_fp16(input[i]);
}
return results;
}
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t
__device__ __forceinline__
static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
if (pred) {
// out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
out[i] = load(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId]);
}
}
}
// weight: column major: [N / BLOCK_N, 1, K / WARP_K, WARP_N_TILES, WARP_SIZE] of packed_wgt_t
__device__ __forceinline__
static void load_wgt(const packed_wgt_t *wgt, int k, int K, wgt_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
// const packed_wgt_t *ptr = &wgt[(0 * K / WARP_K + k) * WARP_SIZE + laneId];
const packed_wgt_t *ptr = &wgt[(0 + k * WARP_N_TILES) * WARP_SIZE + laneId];
// int offset = K / WARP_K * WARP_SIZE;
#pragma unroll
for (int i = 0; i < WARP_N_TILES; i++) {
if (pred) {
// out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
// out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
out[i] = load(&ptr[i * WARP_SIZE]);
// ptr += offset;
}
}
}
// ascales: row major [M / BLOCK_M, K / group size, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
__device__ __forceinline__
static void load_ascale(const packed_ascale_t *ascales, int group, int M, ascale_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
if (pred && laneId < ASCALES_VALID_LANES) {
// out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
out[i] = ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId];
}
}
}
// wscales: column major [N / BLOCK_N, K / group size, 1, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
__device__ __forceinline__
static void load_wscale(const packed_wscale_t *wscales, int group, int N, wscale_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
// static_assert(WSCALES_NUM_PACKS * WSCALES_VALID_LANES == 32);
// static_assert(sizeof(packed_wscale_t) == 8);
// const packed_wscale_t *ptr = &wscales[(group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId];
// // const packed_wscale_t *ptr = (const packed_wscale_t *)((const char *)wscales) + ((group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId) * sizeof(packed_wscale_t);
#pragma unroll
for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
if (pred && laneId < WSCALES_VALID_LANES) {
// out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
// out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
out[i] = load(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId]);
// out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
}
}
}
// get {k}-th and {k+1}-th wscale from the block, k must be multiples of 2, k must be uniform across all lanes
__device__ __forceinline__
static half2_t broadcast_wscale(wscale_warp block, int k, int laneId) {
const int packIdx = k / (WSCALES_PACK_SIZE * WARP_SIZE);
const int srcLane = 4 * (k / WSCALES_PACK_SIZE) + laneId % 4;
const int elementIdx = k % WSCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
}
// get {k}-th and {k+1}-th ascale from the block, k must be multiples of 2, k must be uniform across all lanes
__device__ __forceinline__
static half2_t broadcast_ascale(ascale_warp block, int k, int laneId) {
const int packIdx = k / (ASCALES_PACK_SIZE * WARP_SIZE);
const int srcLane = 8 * (k / ASCALES_PACK_SIZE) + laneId / 4;
const int elementIdx = k % ASCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
}
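// Dequantize one warp tile of integer partial sums: for each 16x16 fragment,
// fpsum += float(psum) * ascale(row) * wscale(col), with the scales broadcast from the
// packed per-warp copies loaded above.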
template<typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
half2_t asx[WARP_M_TILES];
half2_t asy[WARP_M_TILES];
for (int i = 0; i < WARP_M_TILES; i++) {
half2_t as = broadcast_ascale(ascale, i * 2, laneId);
asx[i] = half2_t(as.x, as.x);
asy[i] = half2_t(as.y, as.y);
}
for (int j = 0; j < WARP_N_TILES; j++) {
half2_t ws1 = broadcast_wscale(wscale, j * 4, laneId);
half2_t ws2 = broadcast_wscale(wscale, j * 4 + 2, laneId);
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
packed_psum_t psum = getpsum(i, j);
// constexpr int target = 0;
// if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
// if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
}
}
}
template<typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, f32psum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
float2 asx[WARP_M_TILES];
float2 asy[WARP_M_TILES];
for (int i = 0; i < WARP_M_TILES; i++) {
half2_t as = broadcast_ascale(ascale, i * 2, laneId);
asx[i] = half22float2(half2_t(as.x, as.x));
asy[i] = half22float2(half2_t(as.y, as.y));
}
auto fma2 = [](float2 a, float2 b, float &cx, float &cy) ALWAYSINLINE {
cx += a.x * b.x;
cy += a.y * b.y;
};
for (int j = 0; j < WARP_N_TILES; j++) {
float2 ws1 = half22float2(broadcast_wscale(wscale, j * 4, laneId));
float2 ws2 = half22float2(broadcast_wscale(wscale, j * 4 + 2, laneId));
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
packed_psum_t psum = getpsum(i, j);
fma2(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1])), asx[i] * ws1, fsum.data[0], fsum.data[1]);
fma2(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3])), asy[i] * ws1, fsum.data[2], fsum.data[3]);
fma2(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5])), asx[i] * ws2, fsum.data[4], fsum.data[5]);
fma2(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7])), asy[i] * ws2, fsum.data[6], fsum.data[7]);
}
}
}
/**
* input: WARP_M of half (in shared memory, per-warp)
* output: [..., ASCALES_NUM_PACKS, ASCALES_VALID_LANES] in global memory, per-warp
*/
__device__ __forceinline__
static void pack_ascales(const half_t *input, packed_ascale_t *output) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int j = 0; j < ASCALES_NUM_PACKS; j++) {
if (laneId < ASCALES_VALID_LANES) {
packed_ascale_t tmp;
#pragma unroll
for (int i = 0; i < ASCALES_PACK_SIZE; i += 2) {
tmp.data[i / 2].x = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE + laneId % 8 + i * 8];
tmp.data[i / 2].y = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE + laneId % 8 + (i + 1) * 8];
}
output[j * ASCALES_VALID_LANES + laneId] = tmp;
}
}
}
/**
* input: WARP_N of half (in shared memory, per-warp)
* output: [..., WSCALES_NUM_PACKS, WSCALES_VALID_LANES] in global memory, per-warp
*/
__device__ __forceinline__
static void pack_wscales(const half_t *input, packed_wscale_t *output) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int j = 0; j < WSCALES_NUM_PACKS; j++) {
if (laneId < WSCALES_VALID_LANES) {
packed_wscale_t tmp;
#pragma unroll
for (int i = 0; i < WSCALES_PACK_SIZE; i += 2) {
tmp.data[i / 2] = *reinterpret_cast<const half2_t *>(&input[j * WSCALES_PACK_SIZE * WARP_SIZE + laneId / 4 * 4 * WSCALES_PACK_SIZE + laneId % 4 * 2 + i * 4]);
}
store(&output[j * WSCALES_VALID_LANES + laneId], tmp);
}
}
}
struct unpack_fpsum {
// +8 to prevent bank conflicts
using matrix_t = half_t[8][WARP_N + 8];
static constexpr int SHMEM_SIZE = sizeof(matrix_t);
static constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>;
// F (int rowId, pack_t &pack)
template<typename ...F>
__device__ __forceinline__
void operator()(fpsum_warp fpsum, half_t *output, int stride, int maxRows, int maxCols, void *shmem, F &&...plugins) {
const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
// pack_t reduce_tmp;
// constexpr bool enableReduce = !std::is_void_v<FuncReduce>;
// if constexpr (enableReduce) {
// reduce_tmp.fill(reduce_initval);
// // reduce_tmp = load<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]));
// }
// auto doReduce = [&reduce_tmp](pack_t pack) {
// if constexpr (enableReduce) {
// for (int i = 0; i < PACK_SIZE; i++) {
// reduce_tmp[i] = FuncReduce()(reduce_tmp[i], pack[i]);
// }
// }
// };
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 * 2 + j * INSN_N;
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[2];
}
__syncwarp();
#pragma unroll
for (int row = 0; row < 8; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
// if constexpr (enableReduce) {
// doReduce(pack);
// }
(plugins(i * INSN_M + row, pack), ...);
bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
}
}
__syncwarp();
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 * 2 + j * INSN_N;
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[1];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[3];
}
__syncwarp();
#pragma unroll
for (int row = 0; row < 8; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
// if constexpr (enableReduce) {
// doReduce(pack);
// }
(plugins(i * INSN_M + 8 + row, pack), ...);
bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack);
}
}
__syncwarp();
}
// if (enableReduce) {
// store<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]), reduce_tmp);
// }
}
};
template<typename F>
__device__ __forceinline__
static fpsum_warp apply_act(fpsum_warp fpsum, F func) {
fpsum_warp result;
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst.x = func(src.x);
dst.y = func(src.y);
}
}
}
return result;
}
struct EpilogueDefault {
struct Arguments {
half_t *out;
int actualM, actualN;
};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
const int m_offset = binfo.bm * BLOCK_M + warpId * WARP_M;
const int n_offset = binfo.bn * BLOCK_N;
unpack_fpsum()(
fpsum,
args.out + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
shmem[warpId],
[&](int rowId, unpack_fpsum::pack_t &pack) ALWAYSINLINE {
if constexpr (std::is_same_v<half_t, half>) {
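// saturate to the largest finite fp16 magnitude (65504) so the narrowed
// output never becomes +-inf when half_t is half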
#pragma unroll
for (int i = 0; i < pack.size(); i++) {
pack[i] = __hmin(pack[i], (half)65504);
pack[i] = __hmax(pack[i], (half)-65504);
}
}
}
);
}
};
struct EpilogueNop {
// workaround for layout mismatch between host and device code
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
}
};
struct EpilogueBias {
struct Arguments {
const packed_wscale_t *bias; // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
};
__device__ __forceinline__
void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias) {
const int laneId = threadIdx.x % WARP_SIZE;
// if (laneId == 0) {
// printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
// }
wscale_warp b;
load_wscale(bias, 0, N, b, true);
for (int j = 0; j < WARP_N_TILES; j++) {
half2_t b1 = broadcast_wscale(b, j * 4, laneId);
half2_t b2 = broadcast_wscale(b, j * 4 + 2, laneId);
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
fsum.data[0] = __hadd2(fsum.data[0], b1);
fsum.data[1] = __hadd2(fsum.data[1], b1);
fsum.data[2] = __hadd2(fsum.data[2], b2);
fsum.data[3] = __hadd2(fsum.data[3], b2);
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bn = binfo.bn;
apply_bias(
fpsum, M, N, K,
args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
}
};
struct EpilogueSilu {
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
fpsum = apply_act(fpsum, [](half_t x) { return silu(x); });
}
};
template<typename ...Epilogues>
struct EpilogueCombination {
using Arguments = std::tuple<typename Epilogues::Arguments...>;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
// this function makes intellisense crash :(
#if __INTELLISENSE__
__trap(); // should not happen when actually compiling
#else
std::tuple<Epilogues...> epilogues;
auto run = [&]<size_t idx>() {
std::get<idx>(epilogues).operator()(binfo, fpsum, M, N, K, std::get<idx>(args));
};
auto foreach = [&]<size_t ...Is>(std::index_sequence<Is...>) {
(run.template operator()<Is>(), ...);
};
foreach(std::make_index_sequence<sizeof...(Epilogues)>());
#endif
}
};
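/**
 * Usage sketch (illustrative, assuming a hypothetical bias + writeback chain):
 * the epilogues run in declaration order and their Arguments are passed as a
 * matching std::tuple, e.g.
 *   using Epilogue = EpilogueCombination<EpilogueBias, EpilogueDefault>;
 *   Epilogue::Arguments args{
 *       EpilogueBias::Arguments{bias_ptr},
 *       EpilogueDefault::Arguments{out_ptr, actualM, actualN},
 *   };
 *   Epilogue()(binfo, fpsum, M, N, K, args);
 */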
};
#define IMPORT_GEMM_BASE(config) \
using Base = GEMMBase<config>; \
using Base::BLOCK_M; \
using Base::BLOCK_N; \
using Base::WARP_SIZE; \
using Base::NUM_WARPS; \
using Base::INSN_M; \
using Base::INSN_N; \
using Base::INSN_K; \
using typename Base::half_t; \
using typename Base::half2_t; \
using Base::WARP_M; \
using Base::WARP_N; \
using Base::WARP_K; \
using Base::WARP_M_TILES; \
using Base::WARP_N_TILES; \
using Base::WARP_K_TILES; \
using Base::WSCALES_PACK_SIZE; \
using Base::WSCALES_NUM_PACKS; \
using Base::WSCALES_VALID_LANES; \
using Base::ASCALES_PACK_SIZE; \
using Base::ASCALES_NUM_PACKS; \
using Base::ASCALES_VALID_LANES; \
using typename Base::packed_act_t; \
using typename Base::packed_wgt_t; \
using typename Base::packed_psum_t; \
using typename Base::packed_fpsum_t; \
using typename Base::packed_gated_fpsum_t; \
using typename Base::packed_f32psum_t; \
using typename Base::packed_wscale_t; \
using typename Base::packed_ascale_t; \
using typename Base::act_warp; \
using typename Base::wgt_warp; \
using typename Base::ascale_warp; \
using typename Base::wscale_warp; \
using typename Base::fpsum_warp; \
using typename Base::f32psum_warp; \
using typename Base::gated_fpsum_warp; \
using typename Base::BlockInfo; \
using typename Base::unpack_fpsum; \
using typename Base::EpilogueDefault; \
using typename Base::EpilogueNop; \
using typename Base::EpilogueBias;
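/*
 * Usage sketch (for illustration only): a derived kernel pulls the base
 * constants and type aliases into its own scope with this macro, e.g.
 *   template<typename Config>
 *   class SomeKernel : public GEMMBase<Config> {
 *   public:
 *       IMPORT_GEMM_BASE(Config);
 *       // ... kernel-specific members ...
 *   };
 */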
template<typename kernel, typename ...T>
__global__
static void invoke_kernel(T ...args) {
kernel()(args...);
}
template<typename T>
__global__
static void test_sizeof_device() {
printf("sizeof on device = %d\n", (int)sizeof(T));
}
template<typename T>
static void test_sizeof_host() {
printf("sizeof on host = %d\n", (int)sizeof(T));
}
template<typename T>
static void test_sizeof() {
printf("typeid = %s\n", typeid(T).name());
test_sizeof_host<T>();
test_sizeof_device<T><<<1, 1>>>();
checkCUDA(cudaDeviceSynchronize());
}
}; // namespace nunchaku::kernels
\ No newline at end of file
@@ -2,7 +2,9 @@
#include <cstdint>
#include "common.h"
#include "utils.cuh"
#include "../utils.cuh"
namespace nunchaku::kernels {
static constexpr int clamp(int val, int min, int max) {
if (val < min)
@@ -74,25 +76,19 @@ static void store(T *addr, T val) {
*addr = val;
}
template<typename T>
__device__ __forceinline__
float2 half22float2(T val);
template<>
__device__ __forceinline__
float2 half22float2<half2>(half2 val) {
static float2 half22float2(half2 val) {
return __half22float2(val);
}
template<>
__device__ __forceinline__
float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
static float2 half22float2(__nv_bfloat162 val) {
return __bfloat1622float2(val);
}
template<typename T>
__device__ __forceinline__
T float22half2(float2 val);
static T float22half2(float2 val) = delete;
template<>
__device__ __forceinline__
@@ -108,7 +104,7 @@ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
template<typename T>
__device__ __forceinline__
void unused_var(T &val, bool alwaysfalse) {
static void unused_var(T &val, bool alwaysfalse) {
volatile T *ptr = nullptr;
if (alwaysfalse) {
*ptr = val;
@@ -218,7 +214,7 @@ static float cuda_sigmoidf (float a)
template<typename T>
__device__ __forceinline__
static T gelu_half2(T x) {
float2 xf = half22float2<T>(x);
float2 xf = half22float2(x);
float2 x3f = xf * xf * xf;
float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));
@@ -242,6 +238,25 @@ static T silu(const T &x) {
// return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
}
__device__ __forceinline__
static half2 h2div(half2 a, half2 b) {
float2 af = half22float2(a);
float2 bf = half22float2(b);
float2 of;
of.x = __fdividef(af.x, bf.x);
of.y = __fdividef(af.y, bf.y);
return float22half2<half2>(of);
};
__device__ __forceinline__
static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 af = half22float2(a);
float2 bf = half22float2(b);
float2 of;
of.x = __fdividef(af.x, bf.x);
of.y = __fdividef(af.y, bf.y);
return float22half2<__nv_bfloat162>(of);
};
__device__ __forceinline__
static void reduce_add(float *addr, float val) {
asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));
@@ -254,4 +269,6 @@ static void unrolled_loop(F &&lambda) {
(lambda.template operator()<Is>(), ...);
};
call(std::make_integer_sequence<int, cnt>());
}
\ No newline at end of file
}
}; // namespace nunchaku::kernels
\ No newline at end of file
#include "zgemm.h"
#include "gemm_w4a4_launch.cuh"
namespace nunchaku::kernels {
template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, F &&launch) {
if (dtype == Tensor::FP16) {
launch.template operator()<GEMMConfig_W4A4_FP16>();
} else if (dtype == Tensor::BF16) {
launch.template operator()<GEMMConfig_W4A4_BF16>();
} else {
assert(false);
}
}
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn, // linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
) {
invoke_launch(ascales.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::gemm_w4a4(
act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu
);
});
}
void linearattn_vk_mul_q(Tensor q, Tensor vk) {
invoke_launch(q.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(q, vk);
});
}
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu
);
});
}
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_act(
input, output, oscales
);
});
}
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(
input, output, oscales
);
});
}
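// Usage sketch (illustrative only, not part of the original commit): a minimal
// unfused W4A4 matmul built from the wrappers above. It assumes the caller has
// already allocated tensors in the packed layouts documented in gemm_w4a4's
// signature (packed act/wgt, packed ascales/wscales) and that weights were
// quantized ahead of time with quantize_w4a4_wgt; whether every optional fusion
// can really be left empty depends on the launch implementation.
[[maybe_unused]] static void example_w4a4_matmul(
    Tensor input,       // linear [M, K]
    Tensor act_packed,  // packed act [M, K / 2]
    Tensor ascales,     // packed as [K / 64, M]
    Tensor wgt_packed,  // packed wgt [N, K / 2]
    Tensor wscales,     // packed ws [K / 64, N]
    Tensor out)         // linear [M, N]
{
    // quantize the activation on the fly
    quantize_w4a4_act(input, act_packed, ascales);
    // fused W4A4 GEMM; LoRA / norm / rotary / bias / attention outputs disabled
    gemm_w4a4(act_packed, wgt_packed, out, {},
              ascales, wscales, {}, {},
              {}, {}, {}, {},
              {}, {}, {},
              {}, {},
              {}, {},
              false, {}, false);
}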
};
\ No newline at end of file
#include "common.h"
#include "Tensor.h"
#pragma once
#include "utils.cuh"
#include "gemm_utils.cuh"
#include "gemm_base.cuh"
#include "dispatch_utils.h"
namespace nunchaku::kernels {
#pragma nv_diag_suppress 177
template<typename Config>
class GEMM_W4A4;
#ifdef _MSC_VER
#define ALWAYSINLINE [[msvc::forceinline]]
#else
#define ALWAYSINLINE __attribute__((always_inline))
#endif
#ifndef __INTELLISENSE__
template<typename Config>
class GEMM_W4A4 : public GEMMBase<Config> {
#else
template<>
class GEMM_W4A4<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
using Config = GEMMConfig_W4A4_FP16;
#endif
// #define ENABLE_NAN_CHECK 1
#if ENABLE_NAN_CHECK
#define STRINGIZE(x) STRINGIZE2(x)
#define STRINGIZE2(x) #x
#define CHECK_NAN(data, name) checkNan(data, name " at " STRINGIZE(__LINE__))
#else
#define CHECK_NAN(data, name)
#endif
class GEMMConfig_W4A4 {
public:
// BE CAREFUL: weights need to be repacked when the tiling size changes
static constexpr int BLOCK_M = 256;
static constexpr int BLOCK_N = 128;
static constexpr int WARP_SIZE = 32;
static constexpr int NUM_WARPS = 8;
static constexpr int INSN_M = 16;
static constexpr int INSN_N = 16;
static constexpr int INSN_K = 64;
#if 0
using half_t = half;
using half2_t = half2;
#else
using half_t = __nv_bfloat16;
using half2_t = __nv_bfloat162;
#endif
};
class GEMMConfig_W8A8 {
public:
static constexpr int BLOCK_M = 256;
static constexpr int BLOCK_N = 128;
static constexpr int WARP_SIZE = 32;
static constexpr int NUM_WARPS = 8;
IMPORT_GEMM_BASE(Config);
static constexpr int INSN_M = 16;
static constexpr int INSN_N = 16;
static constexpr int INSN_K = 32;
using half_t = half;
using half2_t = half2;
};
template<class Config>
class GEMMBase : public Config {
public:
using Config::BLOCK_M;
using Config::BLOCK_N;
using Config::WARP_SIZE;
using Config::NUM_WARPS;
using Config::INSN_M;
using Config::INSN_N;
using Config::INSN_K;
using typename Config::half_t;
using typename Config::half2_t;
static constexpr int WARP_M = BLOCK_M / NUM_WARPS;
static constexpr int WARP_N = BLOCK_N;
static constexpr int WARP_K = INSN_K;
static constexpr int WARP_M_TILES = WARP_M / INSN_M;
static constexpr int WARP_N_TILES = WARP_N / INSN_N;
static constexpr int WARP_K_TILES = WARP_K / INSN_K;
/**
* refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
*
* wscales store order: (pack = 4)
* 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
* ... ...
* 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* wscales store order: (pack = 8)
* 0 1 8 9 16 17 24 25 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 18 19 26 27 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 20 21 28 29 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 22 23 30 31 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* 224 225 232 233 240 241 248 249 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 230 231 238 239 246 247 254 255 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* {k}-th wscale used by lane {i} => {k // (WSCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {4*(k // WSCALES_PACK_SIZE) + i % 4}, element {k % WSCALES_PACK_SIZE}
*
* max pack size set to 8 since max load size is 16 bytes / lane
* min pack size set to 2 since the shuffle granularity is 32 bits (2 x half)
* */
static constexpr int WSCALES_PACK_SIZE = clamp(WARP_N / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
static constexpr int WSCALES_NUM_PACKS = ceilDiv(WARP_N, (WSCALES_PACK_SIZE * WARP_SIZE));
static constexpr int WSCALES_VALID_LANES = std::min(WARP_SIZE, WARP_N / WSCALES_PACK_SIZE);
/**
* ascales store order: (pack = 2)
* 0 8 <-- load by lane 0, broadcast to lane {0, 1, 2, 3} (4x)
* 1 9 <-- load by lane 1, broadcast to lane {4, 5, 6, 7} (4x)
* 2 10
* ...
* 6 14
* 7 15 <-- load by lane 7, broadcast to lane {28, 29, 30, 31} (4x)
* ... ...
* 48 56 <-- load by lane 24, broadcast to lane {0, 1, 2, 3} (4x)
* 49 57
* ...
* 54 62
* 55 63 <-- load by lane 31, broadcast to lane {28, 29, 30, 31} (4x)
*
* {k}-th ascale used by lane {i} => {k // (ASCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {8*(k // ASCALES_PACK_SIZE) + i // 4}, element {k % ASCALES_PACK_SIZE}
*/
static constexpr int ASCALES_PACK_SIZE = clamp(WARP_M / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
static constexpr int ASCALES_NUM_PACKS = ceilDiv(WARP_M, (ASCALES_PACK_SIZE * WARP_SIZE));
static constexpr int ASCALES_VALID_LANES = std::min(WARP_SIZE, WARP_M / ASCALES_PACK_SIZE);
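/**
 * Worked example (for illustration, assuming the GEMMConfig_W4A4 tiling above:
 * BLOCK_M = 256, BLOCK_N = 128, NUM_WARPS = 8, sizeof(half_t) = 2):
 *   WARP_M = 32, WARP_N = 128, WARP_M_TILES = 2, WARP_N_TILES = 8, WARP_K_TILES = 1
 *   WSCALES_PACK_SIZE = clamp(128 / 32, 2, 8) = 4, WSCALES_NUM_PACKS = 1, WSCALES_VALID_LANES = 32
 *   ASCALES_PACK_SIZE = clamp(32 / 32, 2, 8) = 2, ASCALES_NUM_PACKS = 1, ASCALES_VALID_LANES = 16
 * i.e. the 32 lanes hold one 4-wide wscale pack each (covering all 128 weight
 * scales of a warp tile), and the first 16 lanes hold one 2-wide ascale pack
 * each (covering all 32 activation scales of a warp tile).
 */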
using packed_act_t = uint4;
using packed_wgt_t = uint4;
struct alignas(32) packed_psum_t {
int data[8];
};
struct alignas(16) packed_fpsum_t {
half2_t data[4];
};
struct alignas(8) packed_gated_fpsum_t {
half_t data[4];
};
// 16 * 16 matrix
struct alignas(32) packed_f32psum_t {
float data[8];
static constexpr packed_f32psum_t zeros() {
packed_f32psum_t result;
for (int i = 0; i < 8; i++) {
result.data[i] = 0;
}
return result;
};
};
struct packed_wscale_t {
half2_t data[WSCALES_PACK_SIZE / 2];
};
struct packed_ascale_t {
half2_t data[ASCALES_PACK_SIZE / 2];
};
using act_warp = std::array<packed_act_t, WARP_M_TILES>;
using wgt_warp = std::array<packed_wgt_t, WARP_N_TILES>;
using ascale_warp = std::array<packed_ascale_t, ASCALES_NUM_PACKS>;
using wscale_warp = std::array<packed_wscale_t, WSCALES_NUM_PACKS>;
using fpsum_warp = std::array<packed_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
using f32psum_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>;
using gated_fpsum_warp = std::array<packed_gated_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
struct BlockInfo {
int bm;
int bn;
int numBlocksM;
int numBlocksN;
};
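/**
 * Example: for the block at (bm, bn), a warp's output tile starts at row
 * bm * BLOCK_M + warpId * WARP_M and column bn * BLOCK_N of the [M, N] output;
 * this is the indexing used by EpilogueDefault.
 */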
template<bool ACT_UNSIGNED>
__device__ __forceinline__
static packed_f32psum_t mma_f16xf16_f32(packed_fpsum_t a, packed_fpsum_t b, packed_f32psum_t psum) {
static_assert(std::is_same_v<half_t, half> || std::is_same_v<half_t, __nv_bfloat16>);
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) {
packed_psum_t psum;
if constexpr (std::is_same_v<half_t, half>) {
if constexpr (!ACT_UNSIGNED) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[1])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[3])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
if constexpr (ACT_UNSIGNED) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[1])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(*reinterpret_cast<unsigned int *>(&a.data[0])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[1])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[2])),
"r"(*reinterpret_cast<unsigned int *>(&b.data[3])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
return psum;
}
// template<bool si>
template<bool use_unsigned>
__device__ __forceinline__
static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
packed_fpsum_t results;
for (int i = 0; i < 4; i++) {
results.data[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1]));
}
return results;
}
static void quantize_w4a4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, half_t *output_scale) {
const int laneId = threadIdx.x % WARP_SIZE;
__device__ __forceinline__
static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
packed_f32psum_t results;
for (int i = 0; i < 4; i++) {
float2 tmp = half22float2(input.data[i]);
results.data[i * 2] = tmp.x;
results.data[i * 2 + 1] = tmp.y;
}
return results;
}
constexpr float QVALUE_MAX_SIGNED = 7.0f;
constexpr float QVALUE_MAX_UNSIGNED = 15.0f;
constexpr float RECPI_QVALUE_MAX_SIGNED = 1 / QVALUE_MAX_SIGNED;
constexpr float RECPI_QVALUE_MAX_UNSIGNED = 1 / QVALUE_MAX_UNSIGNED;
__device__ __forceinline__
static fpsum_warp packed_fp32_to_fp16(f32psum_warp input) {
fpsum_warp results;
#pragma unroll
for (int i = 0; i < results.size(); i++) {
results[i] = packed_fp32_to_fp16(input[i]);
}
return results;
}
constexpr float QVALUE_MAX = use_unsigned ? QVALUE_MAX_UNSIGNED : QVALUE_MAX_SIGNED;
constexpr float RECPI_QVALUE_MAX = use_unsigned ? RECPI_QVALUE_MAX_UNSIGNED : RECPI_QVALUE_MAX_SIGNED;
// constexpr int QUANTIZE_BITMASK = 0xf;
// 0 for row 0-7; 1 for row 8-15
half2_t input[2][INSN_K / INSN_N * 2];
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t
__device__ __forceinline__
static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
if (pred) {
// out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
out[i] = load(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId]);
}
for (int i = 0; i < INSN_K / INSN_N; i++) {
input[0][i * 2 + 0] = fpsum[i].data[0];
input[0][i * 2 + 1] = fpsum[i].data[2];
input[1][i * 2 + 0] = fpsum[i].data[1];
input[1][i * 2 + 1] = fpsum[i].data[3];
}
}
// weight: column major: [N / BLOCK_N, 1, K / WARP_K, WARP_N_TILES, WARP_SIZE] of packed_wgt_t
__device__ __forceinline__
static void load_wgt(const packed_wgt_t *wgt, int k, int K, wgt_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
// const packed_wgt_t *ptr = &wgt[(0 * K / WARP_K + k) * WARP_SIZE + laneId];
const packed_wgt_t *ptr = &wgt[(0 + k * WARP_N_TILES) * WARP_SIZE + laneId];
// int offset = K / WARP_K * WARP_SIZE;
half_t maxvalue[2];
maxvalue[0] = 0;
maxvalue[1] = 0;
#pragma unroll
for (int i = 0; i < WARP_N_TILES; i++) {
if (pred) {
// out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
// out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
out[i] = load(&ptr[i * WARP_SIZE]);
// ptr += offset;
}
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
half2_t abs0 = __habs2(input[0][i]);
half2_t abs1 = __habs2(input[1][i]);
maxvalue[0] = __hmax(maxvalue[0], __hmax(abs0.x, abs0.y));
maxvalue[1] = __hmax(maxvalue[1], __hmax(abs1.x, abs1.y));
}
}
// ascales: row major [M / BLOCK_M, K / group size, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
__device__ __forceinline__
static void load_ascale(const packed_ascale_t *ascales, int group, int M, ascale_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
if (pred && laneId < ASCALES_VALID_LANES) {
// out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
out[i] = ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId];
}
for (int mask = 2; mask > 0; mask /= 2) {
maxvalue[0] = __hmax(maxvalue[0], __shfl_xor_sync(~0, maxvalue[0], mask));
maxvalue[1] = __hmax(maxvalue[1], __shfl_xor_sync(~0, maxvalue[1], mask));
}
}
// wscales: column major [N / BLOCK_N, K / group size, 1, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
__device__ __forceinline__
static void load_wscale(const packed_wscale_t *wscales, int group, int N, wscale_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
maxvalue[0] = __shfl_sync(~0, maxvalue[0], laneId / 4 * 4);
maxvalue[1] = __shfl_sync(~0, maxvalue[1], laneId / 4 * 4);
// static_assert(WSCALES_NUM_PACKS * WSCALES_VALID_LANES == 32);
// static_assert(sizeof(packed_wscale_t) == 8);
float scale[2];
// scale[0] = float(maxvalue[0]) / QVALUE_MAX;
// scale[1] = float(maxvalue[1]) / QVALUE_MAX;
scale[0] = float(maxvalue[0]) * RECPI_QVALUE_MAX;
scale[1] = float(maxvalue[1]) * RECPI_QVALUE_MAX;
if (laneId % 4 == 0) {
output_scale[laneId / 4] = half_t(scale[0]);
output_scale[laneId / 4 + 8] = half_t(scale[1]);
}
// const packed_wscale_t *ptr = &wscales[(group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId];
// // const packed_wscale_t *ptr = (const packed_wscale_t *)((const char *)wscales) + ((group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId) * sizeof(packed_wscale_t);
float rscale[2];
// rscale[0] = QVALUE_MAX / float(maxvalue[0]);
// rscale[1] = QVALUE_MAX / float(maxvalue[1]);
rscale[0] = cuda_frcp(scale[0]);
rscale[1] = cuda_frcp(scale[1]);
uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
if (pred && laneId < WSCALES_VALID_LANES) {
// out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
// out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
out[i] = load(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId]);
// out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
// half2_t hval = __hmul2(input[j][i], half2_t(rscale[j], rscale[j]));
// float2 fval = half22float2(hval);
float2 fval = half22float2(input[j][i]) * make_float2(rscale[j], rscale[j]);
qpacks[j][i] = quantize_float2<4, use_unsigned>(fval) << (laneId % 4 * 8);
}
}
}
// get {k}-th and {k+1}-th wscale from the block, k must be multiples of 2, k must be uniform across all lanes
__device__ __forceinline__
static half2_t broadcast_wscale(wscale_warp block, int k, int laneId) {
const int packIdx = k / (WSCALES_PACK_SIZE * WARP_SIZE);
const int srcLane = 4 * (k / WSCALES_PACK_SIZE) + laneId % 4;
const int elementIdx = k % WSCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
}
// get {k}-th and {k+1}-th ascale from the block, k must be multiples of 2, k must be uniform across all lanes
__device__ __forceinline__
static half2_t broadcast_ascale(ascale_warp block, int k, int laneId) {
const int packIdx = k / (ASCALES_PACK_SIZE * WARP_SIZE);
const int srcLane = 8 * (k / ASCALES_PACK_SIZE) + laneId / 4;
const int elementIdx = k % ASCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
}
template<typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
half2_t asx[WARP_M_TILES];
half2_t asy[WARP_M_TILES];
for (int i = 0; i < WARP_M_TILES; i++) {
half2_t as = broadcast_ascale(ascale, i * 2, laneId);
asx[i] = half2_t(as.x, as.x);
asy[i] = half2_t(as.y, as.y);
// 2 * 8 * 2 = 32 instructions => 256 cycles
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
}
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
for (int j = 0; j < WARP_N_TILES; j++) {
half2_t ws1 = broadcast_wscale(wscale, j * 4, laneId);
half2_t ws2 = broadcast_wscale(wscale, j * 4 + 2, laneId);
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
packed_psum_t psum = getpsum(i, j);
// constexpr int target = 0;
// if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
// if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
#pragma unroll
for (int i = 0; i < 4; i++) {
if (laneId % 4 == i) {
output.x = qpacks[0][0 + i];
output.y = qpacks[1][0 + i];
output.z = qpacks[0][4 + i];
output.w = qpacks[1][4 + i];
}
}
}
template<typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, f32psum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
float2 asx[WARP_M_TILES];
float2 asy[WARP_M_TILES];
for (int i = 0; i < WARP_M_TILES; i++) {
half2_t as = broadcast_ascale(ascale, i * 2, laneId);
asx[i] = half22float2(half2_t(as.x, as.x));
asy[i] = half22float2(half2_t(as.y, as.y));
}
auto fma2 = [](float2 a, float2 b, float &cx, float &cy) ALWAYSINLINE {
cx += a.x * b.x;
cy += a.y * b.y;
};
// loads act of [WARP_M, WARP_N] and stores to fpsum_warp
// [WARP_M, WARP_N * 2] when fuse_glu
template<bool fuse_glu>
struct load_act_to_fpsum {
using matrix_t = half_t[WARP_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
for (int j = 0; j < WARP_N_TILES; j++) {
float2 ws1 = half22float2(broadcast_wscale(wscale, j * 4, laneId));
float2 ws2 = half22float2(broadcast_wscale(wscale, j * 4 + 2, laneId));
__device__ __forceinline__
void operator()(const half_t *input, int stride, int maxRows, int maxCols, fpsum_warp &out, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
packed_psum_t psum = getpsum(i, j);
constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using packed_input = std::array<half_t, PACK_SIZE>;
using packed_raw_input = std::array<half2_t, PACK_SIZE>;
fma2(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1])), asx[i] * ws1, fsum.data[0], fsum.data[1]);
fma2(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3])), asy[i] * ws1, fsum.data[2], fsum.data[3]);
fma2(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5])), asx[i] * ws2, fsum.data[4], fsum.data[5]);
fma2(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7])), asy[i] * ws2, fsum.data[6], fsum.data[7]);
#pragma unroll
for (int row = 0; row < WARP_M; row++) {
packed_input pack;
// TODO: numCols not multiples of PACK_SIZE
if constexpr (fuse_glu) {
packed_raw_input raw;
raw.fill(half2_t(0, 0));
bool pred = row < maxRows && laneId * PACK_SIZE * 2 < maxCols;
if (pred) {
raw = load(reinterpret_cast<const packed_raw_input *>(input + row * stride + laneId * PACK_SIZE * 2));
}
#pragma unroll
for (int j = 0; j < PACK_SIZE; j++) {
pack[j] = raw[j].x * silu(raw[j].y);
}
} else {
pack.fill(half_t(0));
bool pred = row < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
pack = load(reinterpret_cast<const packed_input *>(input + row * stride + laneId * PACK_SIZE));
}
}
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
}
}
__syncwarp();
/**
* input: WARP_M of half (in shared memory, per-warp)
* output: [..., ASCALES_NUM_PACKS, ASCALES_VALID_LANES] in global memory, per-warp
*/
__device__ __forceinline__
static void pack_ascales(const half_t *input, packed_ascale_t *output) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int j = 0; j < ASCALES_NUM_PACKS; j++) {
if (laneId < ASCALES_VALID_LANES) {
packed_ascale_t tmp;
#pragma unroll
for (int i = 0; i < ASCALES_PACK_SIZE; i += 2) {
tmp.data[i / 2].x = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE + laneId % 8 + i * 8];
tmp.data[i / 2].y = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE + laneId % 8 + (i + 1) * 8];
for (int m = 0; m < WARP_M_TILES; m++) {
for (int n = 0; n < WARP_N_TILES; n++) {
const int row = m * INSN_M + laneId % 16;
const int col = n * INSN_N + laneId / 16 * 8;
uint4 tmp;
ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
}
output[j * ASCALES_VALID_LANES + laneId] = tmp;
}
__syncwarp();
}
}
};
/**
* input: WARP_N of half (in shared memory, per-warp)
* output: [..., WSCALES_NUM_PACKS, WSCALES_VALID_LANES] in global memory, per-warp
* each warp quantizes a INSN_M * INSN_K (16 * 64) matrix
* input is per-warp (in global memory)
* output is per-thread (in regs)
* output_scale is per-warp (in shared memory)
* shmem must be at least INSN_M * INSN_K * sizeof(element) (16 * 64 * 0.5 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w})
*/
__device__ __forceinline__
static void pack_wscales(const half_t *input, packed_wscale_t *output) {
static void quantize_w4a4_warp(const half_t *input, int stride, packed_act_t &output, half_t *output_scale, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int j = 0; j < WSCALES_NUM_PACKS; j++) {
if (laneId < WSCALES_VALID_LANES) {
packed_wscale_t tmp;
#pragma unroll
for (int i = 0; i < WSCALES_PACK_SIZE; i += 2) {
tmp.data[i / 2] = *reinterpret_cast<const half2_t *>(&input[j * WSCALES_PACK_SIZE * WARP_SIZE + laneId / 4 * 4 * WSCALES_PACK_SIZE + laneId % 4 * 2 + i * 4]);
}
store(&output[j * WSCALES_VALID_LANES + laneId], tmp);
}
}
}
struct unpack_fpsum {
// +8 to prevent bank conflicts
using matrix_t = half_t[8][WARP_N + 8];
static constexpr int SHMEM_SIZE = sizeof(matrix_t);
static constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>;
constexpr int QUANTIZE_BITWIDTH = 4;
constexpr int QVALUE_MAX = 7; // 4 bit => [-8, 7]
// F (int rowId, pack_t &pack)
template<typename ...F>
__device__ __forceinline__
void operator()(fpsum_warp fpsum, half_t *output, int stride, void *shmem, F &&...plugins) {
const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 8 for 4bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>;
packed_input packs[NUM_PACKWARPS];
// pack_t reduce_tmp;
// constexpr bool enableReduce = !std::is_void_v<FuncReduce>;
// load
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
}
// if constexpr (enableReduce) {
// reduce_tmp.fill(reduce_initval);
// // reduce_tmp = load<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]));
// }
// auto doReduce = [&reduce_tmp](pack_t pack) {
// if constexpr (enableReduce) {
// for (int i = 0; i < PACK_SIZE; i++) {
// reduce_tmp[i] = FuncReduce()(reduce_tmp[i], pack[i]);
// }
// }
// };
// find max
half_t maxvalue[NUM_PACKWARPS];
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __habs(packs[i][0]);
#pragma unroll
for (int j = 1; j < PACK_SIZE; j++) {
maxvalue[i] = __hmax(maxvalue[i], __habs(packs[i][j]));
}
}
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 * 2 + j * INSN_N;
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[2];
}
__syncwarp();
// warp reduce (max)
#pragma unroll
for (int mask = NUM_PACKS_PER_ROW / 2; mask > 0; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __hmax(maxvalue[i], __shfl_xor_sync(~0, maxvalue[i], mask));
}
}
#pragma unroll
for (int row = 0; row < 8; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
// broadcast (max)
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __shfl_sync(~0, maxvalue[i], laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
}
// if constexpr (enableReduce) {
// doReduce(pack);
// }
// quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
half_t scale = maxvalue[i] / half_t(QVALUE_MAX);
half_t rscale = half_t(QVALUE_MAX) / maxvalue[i];
if (laneId % NUM_PACKS_PER_ROW == 0) {
output_scale[i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW] = scale;
}
(plugins(i * INSN_M + row, pack), ...);
uint32_t qpack = 0;
// #pragma unroll
// for (int j = 0; j < PACK_SIZE; j++) {
// int intvalue = __half2int_rn(packs[i][j] / scale);
// intvalue = clamp(intvalue, -QVALUE_MAX, QVALUE_MAX);
// qpack |= (intvalue & QUANTIZE_BITMASK) << (QUANTIZE_BITWIDTH * j);
// }
#pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) {
half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
qpack |= quantize_float2<QUANTIZE_BITWIDTH, false>(half22float2(hval)) << (j * QUANTIZE_BITWIDTH);
}
mat[i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW][laneId % NUM_PACKS_PER_ROW] = qpack;
}
__syncwarp();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
ldmatrix(&mat[row][col], output);
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
}
__syncwarp();
__syncwarp();
}
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 * 2 + j * INSN_N;
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[1];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[3];
}
__syncwarp();
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
struct quantize_w4a4_act_kernel {
__device__
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int row = 0; row < 8; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int bk = blockIdx.y;
const int warpId = blockIdx.x % (BLOCK_M / WARP_M);
// if constexpr (enableReduce) {
// doReduce(pack);
// }
const int row = blockIdx.x * WARP_M;
const int col = blockIdx.y * WARP_K;
(plugins(i * INSN_M + 8 + row, pack), ...);
__shared__ alignas(128) half_t oscale_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack);
}
__syncwarp();
}
// if (enableReduce) {
// store<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]), reduce_tmp);
// }
}
};
for (int tileId = 0; tileId < WARP_M_TILES; tileId++) {
packed_act_t tmpout;
quantize_w4a4_warp(
input + (row + tileId * INSN_M) * K + col,
K,
tmpout,
oscale_shmem + tileId * INSN_M,
tmp_shmem
);
struct EpilogueDefault {
// workaround for layout mismatch between host and device code
struct Arguments { size_t unused; };
store(&output[(((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * WARP_M_TILES + tileId) * WARP_SIZE + laneId], tmpout);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, shmem[warpId]);
// if (threadIdx.x == 0) {
// printf("Block (%d, %d) => offset = %d\n", blockIdx.x, blockIdx.y, (bm * K / WARP_K + bk) * NUM_WARPS + warpId);
// }
pack_ascales(oscale_shmem, &oscales[((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
}
};
// each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64)
struct quantize_w4a4_wgt_kernel {
__device__
void operator()(const half_t *input, packed_wgt_t *output, packed_wscale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
struct EpilogueNop {
struct Arguments { size_t unused; };
const int bn = blockIdx.x / (BLOCK_N / WARP_N);
const int bk = blockIdx.y;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
}
};
const int col = blockIdx.x * WARP_N;
const int row = blockIdx.y * WARP_K;
template<typename ...Epilogues>
struct EpilogueCombination {
using Arguments = std::tuple<typename Epilogues::Arguments...>;
__shared__ alignas(128) half_t oscale_shmem[WARP_N];
__shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
// this function makes intellisense crash :(
#if __INTELLISENSE__
__trap(); // should not happen when actually compiling
#else
std::tuple<Epilogues...> epilogues;
auto run = [&]<size_t idx>() {
std::get<idx>(epilogues).operator()(binfo, fpsum, out, M, N, K, std::get<idx>(args));
};
auto foreach = [&]<size_t ...Is>(std::index_sequence<Is...>) {
(run.template operator()<Is>(), ...);
};
foreach(std::make_index_sequence<sizeof...(Epilogues)>());
#endif
}
};
for (int tileId = 0; tileId < WARP_N_TILES; tileId++) {
packed_wgt_t tmpout;
quantize_w4a4_warp(
input + (col + tileId * INSN_N) * K + row,
K,
tmpout,
oscale_shmem + tileId * INSN_N,
tmp_shmem
);
};
std::swap(tmpout.y, tmpout.z);
class GEMM_W4A4 : public GEMMBase<GEMMConfig_W4A4> {
public:
template<bool ACT_UNSIGNED>
__device__ __forceinline__
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) {
packed_psum_t psum;
store(&output[((bn * K / WARP_K + bk) * WARP_N_TILES + tileId) * WARP_SIZE + laneId], tmpout);
}
if constexpr (!ACT_UNSIGNED) {
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
pack_wscales(oscale_shmem, &oscales[(bn * K / WARP_K + bk) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES]);
}
};
if constexpr (ACT_UNSIGNED) {
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
return psum;
template<bool ACT_UNSIGNED, typename T>
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
apply_scales([&](int i, int j) {
return mma<ACT_UNSIGNED>(A[i], W[j]);
}, ascale, wscale, fpsum);
}
// template<bool si>
template<bool use_unsigned>
__device__ __forceinline__
static void quantize_w4a4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, half_t *output_scale) {
static void checkNan(fpsum_warp fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
constexpr float QVALUE_MAX_SIGNED = 7.0f;
constexpr float QVALUE_MAX_UNSIGNED = 15.0f;
constexpr float RECPI_QVALUE_MAX_SIGNED = 1 / QVALUE_MAX_SIGNED;
constexpr float RECPI_QVALUE_MAX_UNSIGNED = 1 / QVALUE_MAX_UNSIGNED;
constexpr float QVALUE_MAX = use_unsigned ? QVALUE_MAX_UNSIGNED : QVALUE_MAX_SIGNED;
constexpr float RECPI_QVALUE_MAX = use_unsigned ? RECPI_QVALUE_MAX_UNSIGNED : RECPI_QVALUE_MAX_SIGNED;
// constexpr int QUANTIZE_BITMASK = 0xf;
// 0 for row 0-7; 1 for row 8-15
half2_t input[2][INSN_K / INSN_N * 2];
#pragma unroll
for (int i = 0; i < INSN_K / INSN_N; i++) {
input[0][i * 2 + 0] = fpsum[i].data[0];
input[0][i * 2 + 1] = fpsum[i].data[2];
input[1][i * 2 + 0] = fpsum[i].data[1];
input[1][i * 2 + 1] = fpsum[i].data[3];
}
half_t maxvalue[2];
maxvalue[0] = 0;
maxvalue[1] = 0;
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
half2_t abs0 = __habs2(input[0][i]);
half2_t abs1 = __habs2(input[1][i]);
maxvalue[0] = __hmax(maxvalue[0], __hmax(abs0.x, abs0.y));
maxvalue[1] = __hmax(maxvalue[1], __hmax(abs1.x, abs1.y));
}
#pragma unroll
for (int mask = 2; mask > 0; mask /= 2) {
maxvalue[0] = __hmax(maxvalue[0], __shfl_xor_sync(~0, maxvalue[0], mask));
maxvalue[1] = __hmax(maxvalue[1], __shfl_xor_sync(~0, maxvalue[1], mask));
}
maxvalue[0] = __shfl_sync(~0, maxvalue[0], laneId / 4 * 4);
maxvalue[1] = __shfl_sync(~0, maxvalue[1], laneId / 4 * 4);
float scale[2];
// scale[0] = float(maxvalue[0]) / QVALUE_MAX;
// scale[1] = float(maxvalue[1]) / QVALUE_MAX;
scale[0] = float(maxvalue[0]) * RECPI_QVALUE_MAX;
scale[1] = float(maxvalue[1]) * RECPI_QVALUE_MAX;
if (laneId % 4 == 0) {
output_scale[laneId / 4] = half_t(scale[0]);
output_scale[laneId / 4 + 8] = half_t(scale[1]);
}
float rscale[2];
// rscale[0] = QVALUE_MAX / float(maxvalue[0]);
// rscale[1] = QVALUE_MAX / float(maxvalue[1]);
rscale[0] = cuda_frcp(scale[0]);
rscale[1] = cuda_frcp(scale[1]);
uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
// half2_t hval = __hmul2(input[j][i], half2_t(rscale[j], rscale[j]));
// float2 fval = half22float2(hval);
float2 fval = half22float2(input[j][i]) * make_float2(rscale[j], rscale[j]);
qpacks[j][i] = quantize_float2<4, use_unsigned>(fval) << (laneId % 4 * 8);
}
}
// 2 * 8 * 2 = 32 instructions => 256 cycles
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
for (int i = 0; i < fpsum.size(); i++) {
for (int j = 0; j < 4; j++) {
bool abnormal = !isfinite((float)fpsum[i].data[j].x) || !isfinite((float)fpsum[i].data[j].y);
if (abnormal) {
printf("abnormal value detected at block.x=%d block.y=%d warpId=%d laneId=%d fpsum_warp (%s) i=%d j=%d data.x=%f data.y=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
i, j,
(float)fpsum[i].data[j].x,
(float)fpsum[i].data[j].y
);
__trap();
}
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
#pragma unroll
for (int i = 0; i < 4; i++) {
if (laneId % 4 == i) {
output.x = qpacks[0][0 + i];
output.y = qpacks[1][0 + i];
output.z = qpacks[0][4 + i];
output.w = qpacks[1][4 + i];
}
}
#endif
}
// loads act of [WARP_M, WARP_N] and stores to fpsum_warp
struct load_act_to_fpsum {
using matrix_t = half_t[WARP_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
__device__ __forceinline__
void operator()(const half_t *input, int stride, fpsum_warp &out, void *shmem /*, const half_t *smooth_factor */) {
const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using packed_input = std::array<half_t, PACK_SIZE>;
// packed_input pack_smooth;
// if (smooth_factor) {
// pack_smooth = load(reinterpret_cast<const packed_input *>(input + laneId * PACK_SIZE));
// }
for (int row = 0; row < WARP_M; row++) {
auto pack = load(reinterpret_cast<const packed_input *>(input + row * stride + laneId * PACK_SIZE));
// if (smooth_factor) {
// for (int i = 0; i < PACK_SIZE; i++) {
// pack[i] *= pack_smooth[i];
// }
// }
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
__syncwarp();
__device__ __forceinline__
static void checkNan(packed_f32psum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
for (int m = 0; m < WARP_M_TILES; m++) {
for (int n = 0; n < WARP_N_TILES; n++) {
const int row = m * INSN_M + laneId % 16;
const int col = n * INSN_N + laneId / 16 * 8;
uint4 tmp;
ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
}
for (int j = 0; j < 8; j++) {
bool abnormal = !isfinite(fpsum.data[j]);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_f32psum_t (%s) j=%d data=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
j,
fpsum.data[j]
);
__trap();
}
__syncwarp();
}
};
#endif
}
/**
* each warp quantizes a INSN_M * INSN_K (16 * 64) matrix
* input is per-warp (in global memory)
* output is per-thread (in regs)
* output_scale is per-warp (in shared memory)
* shmem must be at least INSN_M * INSN_K * sizeof(element) (16 * 64 * 0.5 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w})
*/
__device__ __forceinline__
static void quantize_w4a4_warp(const half_t *input, int stride, packed_act_t &output, half_t *output_scale, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int QUANTIZE_BITWIDTH = 4;
constexpr int QVALUE_MAX = 7; // 4 bit => [-8, 7]
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 8 for 4bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>;
packed_input packs[NUM_PACKWARPS];
// load
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
}
// find max
half_t maxvalue[NUM_PACKWARPS];
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __habs(packs[i][0]);
#pragma unroll
for (int j = 1; j < PACK_SIZE; j++) {
maxvalue[i] = __hmax(maxvalue[i], __habs(packs[i][j]));
}
}
// warp reduce (max)
#pragma unroll
for (int mask = NUM_PACKS_PER_ROW / 2; mask > 0; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __hmax(maxvalue[i], __shfl_xor_sync(~0, maxvalue[i], mask));
}
}
// broadcast (max)
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __shfl_sync(~0, maxvalue[i], laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
}
// quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
half_t scale = maxvalue[i] / half_t(QVALUE_MAX);
half_t rscale = half_t(QVALUE_MAX) / maxvalue[i];
if (laneId % NUM_PACKS_PER_ROW == 0) {
output_scale[i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW] = scale;
}
uint32_t qpack = 0;
// #pragma unroll
// for (int j = 0; j < PACK_SIZE; j++) {
// int intvalue = __half2int_rn(packs[i][j] / scale);
// intvalue = clamp(intvalue, -QVALUE_MAX, QVALUE_MAX);
// qpack |= (intvalue & QUANTIZE_BITMASK) << (QUANTIZE_BITWIDTH * j);
// }
#pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) {
half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
qpack |= quantize_float2<QUANTIZE_BITWIDTH, false>(half22float2(hval)) << (j * QUANTIZE_BITWIDTH);
}
mat[i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW][laneId % NUM_PACKS_PER_ROW] = qpack;
}
__syncwarp();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
ldmatrix(&mat[row][col], output);
__syncwarp();
}
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
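// Launch geometry, inferred from the indexing below: gridDim.x = M / WARP_M,
// gridDim.y = K / WARP_K, with a single 32-thread warp per block.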
struct quantize_w4a4_act_kernel {
__device__
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int bk = blockIdx.y;
const int warpId = blockIdx.x % (BLOCK_M / WARP_M);
const int row = blockIdx.x * WARP_M;
const int col = blockIdx.y * WARP_K;
__shared__ alignas(128) half_t oscale_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
for (int tileId = 0; tileId < WARP_M_TILES; tileId++) {
packed_act_t tmpout;
quantize_w4a4_warp(
input + (row + tileId * INSN_M) * K + col,
K,
tmpout,
oscale_shmem + tileId * INSN_M,
tmp_shmem
);
store(&output[(((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * WARP_M_TILES + tileId) * WARP_SIZE + laneId], tmpout);
}
// if (threadIdx.x == 0) {
// printf("Block (%d, %d) => offset = %d\n", blockIdx.x, blockIdx.y, (bm * K / WARP_K + bk) * NUM_WARPS + warpId);
// }
pack_ascales(oscale_shmem, &oscales[((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
}
};
// each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64)
struct quantize_w4a4_wgt_kernel {
__device__
void operator()(const half_t *input, packed_wgt_t *output, packed_wscale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
const int bn = blockIdx.x / (BLOCK_N / WARP_N);
const int bk = blockIdx.y;
const int col = blockIdx.x * WARP_N;
const int row = blockIdx.y * WARP_K;
__shared__ alignas(128) half_t oscale_shmem[WARP_N];
__shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
for (int tileId = 0; tileId < WARP_N_TILES; tileId++) {
packed_wgt_t tmpout;
quantize_w4a4_warp(
input + (col + tileId * INSN_N) * K + row,
K,
tmpout,
oscale_shmem + tileId * INSN_N,
tmp_shmem
);
std::swap(tmpout.y, tmpout.z);
store(&output[((bn * K / WARP_K + bk) * WARP_N_TILES + tileId) * WARP_SIZE + laneId], tmpout);
}
pack_wscales(oscale_shmem, &oscales[(bn * K / WARP_K + bk) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES]);
}
};
template<bool ACT_UNSIGNED, typename T>
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
apply_scales([&](int i, int j) {
return mma<ACT_UNSIGNED>(A[i], W[j]);
}, ascale, wscale, fpsum);
}
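// compute() above performs the INT4 tensor-core MMAs for one WARP_K slice and, via
// apply_scales, immediately folds the per-row activation scales and per-column weight
// scales into the half-precision accumulator fpsum, so the running sum stays in the
// dequantized domain.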
__device__ __forceinline__
static void checkNan(fpsum_warp fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
for (int i = 0; i < fpsum.size(); i++) {
for (int j = 0; j < 4; j++) {
bool abnormal = !isfinite((float)fpsum[i].data[j].x) || !isfinite((float)fpsum[i].data[j].y);
if (abnormal) {
printf("abnormal value detected at block.x=%d block.y=%d warpId=%d laneId=%d fpsum_warp (%s) i=%d j=%d data.x=%f data.y=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
i, j,
(float)fpsum[i].data[j].x,
(float)fpsum[i].data[j].y
);
__trap();
}
}
}
#endif
}
__device__ __forceinline__
static void checkNan(packed_f32psum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
for (int j = 0; j < 8; j++) {
bool abnormal = !isfinite(fpsum.data[j]);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_f32psum_t (%s) j=%d data=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
j,
fpsum.data[j]
);
__trap();
}
}
#endif
}
__device__ __forceinline__
static void checkNan(packed_fpsum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
for (int j = 0; j < 4; j++) {
bool abnormal = !isfinite((float)fpsum.data[j].x) || !isfinite((float)fpsum.data[j].y);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) j=%d data.x=%f data.y=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
j,
(float)fpsum.data[j].x,
(float)fpsum.data[j].y
);
__trap();
}
}
#endif
}
__device__ __forceinline__
static void checkNan(float data, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
bool abnormal = !isfinite(data);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) data=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
data
);
__trap();
}
#endif
}
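// Note: the checkNan overloads above compile to no-ops unless ENABLE_NAN_CHECK is set;
// they are presumably what the CHECK_NAN(...) calls used throughout the epilogues
// expand to.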
// out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp
template<typename Epilogue, bool ACT_UNSIGNED>
static void gemm_w4a4_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// const packed_wscale_t *bias_ptr,
half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale[NUM_STAGES]; // 1
wscale_warp wscale[NUM_STAGES]; // 2
fpsum_warp fpsum; // 64
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
for (int k = 0; k < NUM_STAGES - 1; k++) {
load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true);
load_ascale(ascales, k, M, ascale[k], true);
load_wscale(wscales, k, N, wscale[k], true);
}
for (auto &pack : fpsum) {
#if 1
for (int i = 0; i < 4; i++) {
pack.data[i].x = 0;
pack.data[i].y = 0;
}
#else
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
}
#endif
}
int dummy = 0;
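// Main loop: a NUM_STAGES-deep (here, double-buffered) software pipeline. The prologue
// above preloads the first stage; each iteration issues the predicated loads for
// iteration k1 + k2 + NUM_STAGES - 1 before computing on the already-resident stage k2.
// The alwaysfalse/dummy clock() read appears to exist only to keep the compiler from
// sinking the loads past the MMAs.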
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
load_ascale(ascales, nextk, M, ascale[idx], pred);
load_wscale(wscales, nextk, N, wscale[idx], pred);
// load_wscale<false>(wscales, wscale[idx], pred);
// __syncthreads();
// if (alwaysfalse) {
// dummy = clock();
// }
compute<ACT_UNSIGNED>(A[k2], W[k2], ascale[k2], wscale[k2], fpsum);
if (alwaysfalse) {
dummy = clock();
}
// asm volatile ("membar.cta;");
}
}
unused_var(dummy, alwaysfalse);
#if 0
auto f16psum = packed_fp32_to_fp16(fpsum);
#else
auto f16psum = fpsum;
#endif
CHECK_NAN(f16psum, "f16psum");
Epilogue()(binfo, f16psum, out, M, N, K, epilogueArgs);
}
template<bool FUSE_GELU, bool USE_UNSIGNED>
struct EpilogueQuantize {
struct Arguments {
packed_act_t *qout;
packed_ascale_t *oscales;
half_t shift_value;
const packed_wscale_t *smooth_factor;
};
static constexpr int NUM_PACKS = INSN_K / INSN_N;
static constexpr int NUM_GROUPS = WARP_N_TILES / NUM_PACKS;
__device__ __forceinline__
void apply_quantize(fpsum_warp fpsum, half_t *out, int M, int N, int K, packed_act_t *qout, packed_ascale_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ half_t oscale_shmem[NUM_WARPS][WARP_M];
wscale_warp smooth;
load_wscale(smooth_factor, 0, N, smooth, true);
#pragma unroll
for (int group = 0; group < NUM_GROUPS; group++) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t tmp[NUM_PACKS];
#pragma unroll
for (int j = 0; j < NUM_PACKS; j++) {
half2_t ws1 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4, laneId);
half2_t ws2 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4 + 2, laneId);
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t src = fpsum[i * WARP_N_TILES + group * NUM_PACKS + j].data[k];
half2_t &dst = tmp[j].data[k];
// dst.x = gelu(src.x);
// dst.y = gelu(src.y);
if constexpr (FUSE_GELU) {
dst = gelu_half2(src);
} else {
dst = src;
}
dst += half2_t(shift_value, shift_value);
// dst = src;
}
auto h2div = [](half2_t a, half2_t b) ALWAYSINLINE {
float2 af = half22float2(a);
float2 bf = half22float2(b);
float2 of;
of.x = __fdividef(af.x, bf.x);
of.y = __fdividef(af.y, bf.y);
return float22half2<half2_t>(of);
};
tmp[j].data[0] = h2div(tmp[j].data[0], ws1);
tmp[j].data[1] = h2div(tmp[j].data[1], ws1);
tmp[j].data[2] = h2div(tmp[j].data[2], ws2);
tmp[j].data[3] = h2div(tmp[j].data[3], ws2);
}
packed_act_t qresult;
quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
store(&qout[((group * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], qresult);
}
__syncwarp();
pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
__syncwarp();
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
apply_quantize(
fpsum, out, M, N, K,
args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,
args.shift_value,
args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
}
};
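// EpilogueQuantize re-quantizes the GEMM output for a following W4A4 GEMM: it optionally
// applies GELU (FUSE_GELU), adds shift_value (typically paired with USE_UNSIGNED
// quantization), divides by the per-channel smooth factor, and packs the result back to
// 4 bits with per-row scales emitted through pack_ascales.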
// using EpilogueQuantizeFuseGelu = EpilogueQuantize<true>;
template<int rank = 32>
struct Lora {
static_assert(rank % 16 == 0);
static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
static constexpr int LORA_R_TILES = LORA_RANK / 16;
static constexpr int LORA_N_TILES = WARP_N / 16;
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, LORA_R_TILES>;
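// Taken together, the two LoRA epilogues compute out += (x @ lora_down) @ lora_up^T with
// per-16-rank scales: EpilogueLoraDown produces lora_act = x @ lora_down (accumulated in
// fp32 in global memory via reduce_add), and EpilogueLoraUp reads it back, applies
// scales[r], and adds the product with lora_up into fpsum. rank must be a multiple of 16
// so each factor maps onto 16x16 MMA tiles.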
// lora_wgt: [N / 16, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
__device__ __forceinline__
static lora_wgt_warp load_lora_wgt(const packed_fpsum_t *ptr) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = ptr + laneId;
lora_wgt_warp result;
#if 0
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
result[n * LORA_R_TILES + r] = load(ptr_lane + (n * LORA_R_TILES + r) * WARP_SIZE);
}
}
#else
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int offset = (n * LORA_R_TILES + r) * WARP_SIZE;
result[n * LORA_R_TILES + r] = load(ptr_lane + offset);
});
});
#endif
return result;
}
// lora_act: [M / BLOCK_M, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static lora_act16_warp load_lora_act(const float *ptr, scale_t scales) {
const int laneId = threadIdx.x % WARP_SIZE;
const float *ptrlane = ptr + laneId;
lora_act16_warp result;
#if 0
#pragma unroll
for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
packed_f32psum_t tmp;
#pragma unroll
for (int j = 0; j < 8; j++) {
const int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
tmp.data[j] = ptrlane[offset];
// tmp.data[j] = ptr[i * 8 * WARP_SIZE + j * WARP_SIZE + laneId];
}
CHECK_NAN(tmp, "load_lora_act.tmp");
result[i] = packed_fp32_to_fp16(tmp);
}
#else
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>{
constexpr int i = m * LORA_R_TILES + r;
packed_f32psum_t tmp;
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
tmp.data[j] = ptrlane[offset] * scales[r];
});
CHECK_NAN(tmp, "load_lora_act.tmp");
result[i] = packed_fp32_to_fp16(tmp);
});
});
#endif
return result;
}
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, lora_act_warp val) {
const int laneId = threadIdx.x % WARP_SIZE;
float *ptrlane = ptr + laneId;
// #pragma unroll
// for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
// #pragma unroll
// for (int j = 0; j < 8; j++) {
// int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[i].data[j]);
// }
// }
unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
reduce_add(&ptrlane[offset], val[i].data[j]);
});
});
}
// __device__ __forceinline__
// static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
// const int laneId = threadIdx.x % WARP_SIZE;
// float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
// unrolled_loop<LORA_R_TILES>([&]<int r>() {
// unrolled_loop<8>([&]<int j>() {
// constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
// });
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
const packed_fpsum_t *lora_wgt_up;
scale_t scales;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, half_t *out, int M, int N, int K, const float *act, const packed_fpsum_t *wgt, const scale_t scales, const BlockInfo binfo) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if constexpr (rank > 0) {
lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales);
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]);
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum);
}
fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
}
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum");
if constexpr (rank == 0) {
return;
}
apply_lora_up(
fpsum, out, M, N, K,
args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE,
args.scales,
binfo // for debug
);
CHECK_NAN(fpsum, "fpsum");
}
};
struct EpilogueLoraDown {
struct Arguments {
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
};
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, half_t *out, int M, int N, int K, float *act, const packed_fpsum_t *wgt) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if constexpr (rank > 0) {
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
// clock_t dummy = 0;
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// if (alwaysfalse) {
// dummy = clock();
// }
}
reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);
// unused_var(dummy, alwaysfalse);
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
if constexpr (rank == 0) {
return;
}
apply_lora_down(
fpsum, out, M, N, K,
args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE
);
}
};
struct quantize_w4a4_fuse_lora_kernel {
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
struct Arguments {
const half_t *input;
const packed_wscale_t *smooth_factor;
packed_act_t *output;
packed_ascale_t *oscales;
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
int M, N;
};
__device__ __forceinline__
void operator()(Arguments args)
{
const BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
const int bm = binfo.bm;
const int bn = binfo.bn;
const int warpId = threadIdx.x / WARP_SIZE;
const int m_offset = bm * BLOCK_M + warpId * WARP_M;
const int n_offset = bn * BLOCK_N;
extern __shared__ uint8_t shmem[];
fpsum_warp fpsum;
// FIXME: smooth factor should change to EpilogueQuantize
load_act_to_fpsum()(
args.input + m_offset * args.N + n_offset,
args.N,
fpsum,
shmem + warpId * SHMEM_PER_WARP
// args.smooth_factor ? args.smooth_factor + n_offset : nullptr
);
CHECK_NAN(fpsum, "fpsum");
// for (int i = 0; i < 16; i++) {
// printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
// bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
// (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
// }
EpilogueLoraDown()(binfo, fpsum, nullptr, args.M, args.N, 0, typename EpilogueLoraDown::Arguments{
.lora_wgt_down = args.lora_wgt_down,
.lora_act = args.lora_act,
});
EpilogueQuantize<false, false>()(binfo, fpsum, nullptr, args.M, args.N, 0, EpilogueQuantize<false, false>::Arguments{
.qout = args.output,
.oscales = args.oscales,
.shift_value = 0,
.smooth_factor = args.smooth_factor
});
}
};
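// quantize_w4a4_fuse_lora_kernel fuses the LoRA-down projection into the activation
// quantization pass: it stages the fp16 input into registers (fpsum), runs
// EpilogueLoraDown to produce lora_act, then EpilogueQuantize to emit 4-bit activations
// plus scales, which avoids an extra round trip through global memory.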
};
struct EpilogueBias {
struct Arguments {
const packed_wscale_t *bias; // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
};
__device__ __forceinline__
void apply_bias(fpsum_warp &fpsum, half_t *out, int M, int N, int K, const packed_wscale_t *bias) {
const int laneId = threadIdx.x % WARP_SIZE;
// if (laneId == 0) {
// printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
// }
wscale_warp b;
load_wscale(bias, 0, N, b, true);
for (int j = 0; j < WARP_N_TILES; j++) {
half2_t b1 = broadcast_wscale(b, j * 4, laneId);
half2_t b2 = broadcast_wscale(b, j * 4 + 2, laneId);
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
fsum.data[0] = __hadd2(fsum.data[0], b1);
fsum.data[1] = __hadd2(fsum.data[1], b1);
fsum.data[2] = __hadd2(fsum.data[2], b2);
fsum.data[3] = __hadd2(fsum.data[3], b2);
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bn = binfo.bn;
apply_bias(
fpsum, out, M, N, K,
args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
}
};
struct EpilogueGelu {
struct Arguments { size_t unused; };
// static constexpr float SHIFT_VALUE = 0.171875f;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t &data = fpsum[i * WARP_N_TILES + j].data[k];
data = gelu_half2(data);
// data = __hadd2(data, half2_t(SHIFT_VALUE, SHIFT_VALUE));
}
}
}
}
};
// template<int PoolSize = 128>
struct EpilogueQKVProj {
struct Arguments {
half_t *pool_out; // [M / PoolSize, N]
const float *rotary_emb; // [M, HEAD_DIM / 2, ROTARY_EMB_NUM_ELEMENTS]
const half_t *rmsnorm_weight_q; // [HEAD_DIM]
const half_t *rmsnorm_weight_k; // [HEAD_DIM]
float epsilon;
};
static constexpr int HEAD_DIM = 128;
static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
static constexpr int PoolSize = 128;
static constexpr int NUM_WARPS_PER_POOL = PoolSize / WARP_M;
static constexpr int NUM_POOLS_PER_BLOCK = BLOCK_M / PoolSize;
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2; // 1 for theta, 2 for {sin, cos} pair
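// RoPE is applied per (even, odd) element pair further below:
//     x0' = x0 * cos - x1 * sin,   x1' = x0 * sin + x1 * cos
// With ROTARY_EMB_NUM_ELEMENTS == 2 the table stores {sin, cos} directly; with 1 it
// stores theta and sin/cos are recomputed on the fly.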
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// const packed_wscale_t *bias_ptr,
// half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
__device__ __forceinline__
static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
using pack_t = unpack_fpsum::pack_t;
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale[NUM_STAGES]; // 1
wscale_warp wscale[NUM_STAGES]; // 2
fpsum_warp fpsum; // 64
using pack_rope_t = std::array<float, PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS>;
constexpr int LANES_PER_HEAD = HEAD_DIM / PACK_SIZE;
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
pack_t reduce_tmp;
__shared__ alignas(128) pack_t pool[NUM_WARPS];
for (int k = 0; k < NUM_STAGES - 1; k++) {
load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true);
load_ascale(ascales, k, M, ascale[k], true);
load_wscale(wscales, k, N, wscale[k], true);
}
// load rmsnorm scales
pack_t rms;
if (laneId < LANES_PER_HEAD) {
rms = load(reinterpret_cast<const pack_t *>(&rmsnorm_weight[laneId * PACK_SIZE]));
for (auto &pack : fpsum) {
#if 1
for (int i = 0; i < 4; i++) {
pack.data[i].x = 0;
pack.data[i].y = 0;
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < PACK_SIZE; i++) {
rms[i] = __shfl_sync(~0, rms[i], laneId % LANES_PER_HEAD);
}
#else
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
}
#endif
}
int dummy = 0;
const float *rotary_emb_base_addr = &rotary_emb[(warpId * WARP_M) * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS + laneId * PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS];
CHECK_NAN(fpsum, "fpsum");
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
// load rope
pack_rope_t rope;
if (laneId < LANES_PER_HEAD) {
// freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
rope = load(reinterpret_cast<const pack_rope_t *>(&rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < rope.size(); i++) {
rope[i] = __shfl_sync(~0, rope[i], laneId % LANES_PER_HEAD);
}
}
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
load_ascale(ascales, nextk, M, ascale[idx], pred);
load_wscale(wscales, nextk, N, wscale[idx], pred);
// load_wscale<false>(wscales, wscale[idx], pred);
// rmsnorm
float sqrsum = 0.0f;
for (int i = 0; i < PACK_SIZE; i++) {
sqrsum += float(pack[i]) * float(pack[i]);
CHECK_NAN(sqrsum, "sqrsum");
}
#pragma unroll
for (int mask = LANES_PER_HEAD / 2; mask > 0; mask /= 2) {
sqrsum += __shfl_xor_sync(~0, sqrsum, mask);
}
sqrsum /= HEAD_DIM;
float coef = cuda_frsqrt(sqrsum + epsilon);
CHECK_NAN(coef, "coef");
// __syncthreads();
// if (alwaysfalse) {
// dummy = clock();
// }
for (int i = 0; i < PACK_SIZE; i++) {
pack[i] *= coef * float(rms[i]);
compute<ACT_UNSIGNED>(A[k2], W[k2], ascale[k2], wscale[k2], fpsum);
CHECK_NAN(rms[i], "rms.wgt");
CHECK_NAN(pack[i], "rms.out");
if (alwaysfalse) {
dummy = clock();
}
#if 1
// rope
for (int i = 0; i < PACK_SIZE; i += 2) {
float2 pack2 = half22float2(half2_t(pack[i], pack[i+1]));
CHECK_NAN(freq[i].x, "rope.freq");
CHECK_NAN(freq[i].y, "rope.freq");
CHECK_NAN(freq[i+1].x, "rope.freq");
CHECK_NAN(freq[i+1].y, "rope.freq");
// half2_t tmp = __hmul2(freq[i], pack2);
// tmp = __hfma2(freq[i+1], pack2, tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
// asm volatile ("membar.cta;");
}
}
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// blockIdx.x, blockIdx.y, warpId, rowId,
// blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
// (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
// );
// __trap();
unused_var(dummy, alwaysfalse);
// half2_t tmp = __hmul2(half2_t(pack2.x, pack2.x), freq[i]);
// tmp = __hfma2(half2_t(pack2.y, pack2.y), freq[i+1], tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
#if 0
auto f16psum = packed_fp32_to_fp16(fpsum);
#else
auto f16psum = fpsum;
#endif
float sin, cos;
CHECK_NAN(f16psum, "f16psum");
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 1) {
sin = cuda_sin(rope[i / 2]);
cos = cuda_cos(rope[i / 2]);
}
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 2) {
sin = rope[i];
cos = rope[i+1];
}
Epilogue()(binfo, f16psum, M, N, K, epilogueArgs);
}
// pack[i] = pack2.x * freq[i].x + pack2.y * freq[i].y;
// pack[i+1] = pack2.x * freq[i+1].x + pack2.y * freq[i+1].y;
template<bool FUSE_GELU, bool USE_UNSIGNED>
struct EpilogueQuantize {
struct Arguments {
packed_act_t *qout;
packed_ascale_t *oscales;
pack[i] = half_t(pack2.x * cos - pack2.y * sin);
pack[i+1] = half_t(pack2.x * sin + pack2.y * cos);
half_t shift_value;
const packed_wscale_t *smooth_factor;
};
CHECK_NAN(pack[i], "rope.out");
CHECK_NAN(pack[i+1], "rope.out");
}
#endif
static constexpr int NUM_PACKS = INSN_K / INSN_N;
static constexpr int NUM_GROUPS = WARP_N_TILES / NUM_PACKS;
// mean pool
for (int i = 0; i < PACK_SIZE; i++) {
reduce_tmp[i] += pack[i];
}
});
__device__ __forceinline__
void apply_quantize(fpsum_warp fpsum, int M, int N, int K, packed_act_t *qout, packed_ascale_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if (!pool_out) {
return;
}
__shared__ half_t oscale_shmem[NUM_WARPS][WARP_M];
store<true>(&pool[warpId], reduce_tmp);
__syncthreads();
wscale_warp smooth;
load_wscale(smooth_factor, 0, N, smooth, true);
if (warpId < NUM_POOLS_PER_BLOCK) {
const int row = warpId * NUM_WARPS_PER_POOL;
reduce_tmp = load<true>(&pool[row]);
#pragma unroll
for (int group = 0; group < NUM_GROUPS; group++) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t tmp[NUM_PACKS];
for (int i = 1; i < NUM_WARPS_PER_POOL; i++) {
pack_t pack = load<true>(&pool[row + i]);
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] += pack[j];
}
}
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] /= PoolSize;
}
store(reinterpret_cast<pack_t *>(pool_out + warpId * N), reduce_tmp);
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < NUM_PACKS; j++) {
half2_t ws1 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4, laneId);
half2_t ws2 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4 + 2, laneId);
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t src = fpsum[i * WARP_N_TILES + group * NUM_PACKS + j].data[k];
half2_t &dst = tmp[j].data[k];
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
// dst.x = gelu(src.x);
// dst.y = gelu(src.y);
if constexpr (FUSE_GELU) {
dst = gelu_half2(src);
} else {
dst = src;
}
assert(binfo.numBlocksN % 3 == 0);
const bool is_q = bn < binfo.numBlocksN / 3;
const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
dst += half2_t(shift_value, shift_value);
// dst = src;
}
if (is_q || is_k) {
apply(
fpsum, out, M, N, K,
args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon
);
} else {
EpilogueDefault()(binfo, fpsum, out, M, N, K, {});
}
}
};
// auto h2div = [](half2_t a, half2_t b) ALWAYSINLINE {
// float2 af = half22float2(a);
// float2 bf = half22float2(b);
// float2 of;
// of.x = __fdividef(af.x, bf.x);
// of.y = __fdividef(af.y, bf.y);
// return float22half2<half2_t>(of);
// };
template<typename Epilogue, bool ACT_UNSIGNED>
struct gemm_w4a4_kernel {
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// const packed_wscale_t *bias,
half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
// printf("Device sizeof(args) = %d", (int)sizeof(epilogueArgs));
tmp[j].data[0] = h2div(tmp[j].data[0], ws1);
tmp[j].data[1] = h2div(tmp[j].data[1], ws1);
tmp[j].data[2] = h2div(tmp[j].data[2], ws2);
tmp[j].data[3] = h2div(tmp[j].data[3], ws2);
}
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
packed_act_t qresult;
quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
store(&qout[((group * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], qresult);
}
if (swapBlockXY) {
std::swap(binfo.bm, binfo.bn);
std::swap(binfo.numBlocksM, binfo.numBlocksN);
__syncwarp();
pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
__syncwarp();
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
// bool fusequant = !out;
gemm_w4a4_block<Epilogue, ACT_UNSIGNED>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (K / WARP_K) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,
wscales + bn * (K / WARP_K) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
// bias ? bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES : nullptr,
out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// out + (bm * N / BLOCK_N + bn) * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE,
M, N, K,
epilogueArgs,
alwaysfalse
apply_quantize(
fpsum, M, N, K,
args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,
args.shift_value,
args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
}
};
};
// using EpilogueQuantizeFuseGelu = EpilogueQuantize<true>;
class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public:
using psum_warp = std::array<packed_psum_t, WARP_M_TILES * WARP_N_TILES>;
template<int rank = 32>
struct Lora {
static_assert(rank % 16 == 0);
__device__ __forceinline__
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
// packed_psum_t psum;
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
return psum;
}
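// The mma() above issues two m16n8k32 s8 x s8 -> s32 MMA instructions covering one
// 16x16 output tile: wgt.{x, y} produce psum.data[0..3] (first 8 columns) and
// wgt.{z, w} produce psum.data[4..7] (second 8 columns).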
static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
static constexpr int LORA_R_TILES = LORA_RANK / 16;
static constexpr int LORA_N_TILES = WARP_N / 16;
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
psum[i * WARP_N_TILES + j] = mma(A[i], W[j], psum[i * WARP_N_TILES + j]);
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, LORA_R_TILES>;
// lora_wgt: [N / 16, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
__device__ __forceinline__
static lora_wgt_warp load_lora_wgt(const packed_fpsum_t *ptr) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = ptr + laneId;
lora_wgt_warp result;
#if 0
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
result[n * LORA_R_TILES + r] = load(ptr_lane + (n * LORA_R_TILES + r) * WARP_SIZE);
}
}
#else
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int offset = (n * LORA_R_TILES + r) * WARP_SIZE;
result[n * LORA_R_TILES + r] = load(ptr_lane + offset);
});
});
#endif
return result;
}
}
/**
* each warp quantizes an INSN_M * INSN_K (16 * 32) matrix
* input is per-warp (in global memory / shared memory)
* rscales is per-warp (in shared memory)
* output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* defaults to quantizing activations; when quantizing weights, the input should be column-major and the output transposed ({x, y, z, w} = {x, z, y, w})
*/
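// As in the 4-bit path, quantization is symmetric per row, but here the reciprocal
// scales are precomputed by the caller (rscale = 127 / absmax, see findmax_warp below)
// and each value is stored as one int8, four values packed per uint32_t.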
template<bool input_shmem = false>
__device__ __forceinline__
static void quantize_w8a8_warp(const half *input, const half *rscales, int stride, packed_act_t &output, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
// lora_act: [M / BLOCK_M, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static lora_act16_warp load_lora_act(const float *ptr, scale_t scales) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int QUANTIZE_BITWIDTH = 8;
// constexpr int QUANTIZE_BITMASK = 0xff;
// constexpr int QVALUE_MAX = 128; // 4 bit => [-128, 127]
const float *ptrlane = ptr + laneId;
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 4 for 8bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half, PACK_SIZE>;
lora_act16_warp result;
#if 0
#pragma unroll
for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
packed_f32psum_t tmp;
#pragma unroll
for (int j = 0; j < 8; j++) {
const int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
tmp.data[j] = ptrlane[offset];
// tmp.data[j] = ptr[i * 8 * WARP_SIZE + j * WARP_SIZE + laneId];
}
CHECK_NAN(tmp, "load_lora_act.tmp");
result[i] = packed_fp32_to_fp16(tmp);
}
#else
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>{
constexpr int i = m * LORA_R_TILES + r;
packed_f32psum_t tmp;
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
tmp.data[j] = ptrlane[offset] * scales[r];
});
CHECK_NAN(tmp, "load_lora_act.tmp");
result[i] = packed_fp32_to_fp16(tmp);
});
});
#endif
return result;
}
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, lora_act_warp val) {
const int laneId = threadIdx.x % WARP_SIZE;
packed_input packs[NUM_PACKWARPS];
float *ptrlane = ptr + laneId;
// load
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load<input_shmem>(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
// #pragma unroll
// for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
// #pragma unroll
// for (int j = 0; j < 8; j++) {
// int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[i].data[j]);
// }
// }
unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
reduce_add(&ptrlane[offset], val[i].data[j]);
});
});
}
// quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
const int row = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
const int col = laneId % NUM_PACKS_PER_ROW;
// __device__ __forceinline__
// static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
// const int laneId = threadIdx.x % WARP_SIZE;
half rscale = rscales[row];
// float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
uint32_t qpack = 0;
#pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) {
half2 hval = __hmul2(make_half2(rscale, rscale), make_half2(packs[i][j], packs[i][j + 1]));
qpack |= quantize_float2<QUANTIZE_BITWIDTH, false>(__half22float2(hval)) << (j * QUANTIZE_BITWIDTH);
// unrolled_loop<LORA_R_TILES>([&]<int r>() {
// unrolled_loop<8>([&]<int j>() {
// constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
// });
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
const packed_fpsum_t *lora_wgt_up;
scale_t scales;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, int M, int N, int K, const float *act, const packed_fpsum_t *wgt, const scale_t scales, const BlockInfo binfo) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if constexpr (rank > 0) {
lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales);
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]);
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum);
}
fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
}
}
}
}
mat[row][col] = qpack;
}
__syncwarp();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
ldmatrix(&mat[row][col], output);
__syncwarp();
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
/**
* each warp finds absmax from a row
*/
template<bool fuse_glu = false>
__device__ __forceinline__
static half findmax_warp(const half_t *input, half_t *output_shmem, int K, bool alwaysfalse) {
const int laneId = threadIdx.x % WARP_SIZE;
CHECK_NAN(fpsum, "fpsum");
using packed_input = std::array<half2_t, 4>;
using packed_gated_input = std::array<half_t, 4>;
if constexpr (rank == 0) {
return;
}
constexpr int PACK_SIZE = sizeof(packed_input) / sizeof(half_t);
constexpr int NUM_STAGES = 2;
apply_lora_up(
fpsum, M, N, K,
args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE,
args.scales,
binfo // for debug
);
CHECK_NAN(fpsum, "fpsum");
}
};
struct EpilogueLoraDown {
struct Arguments {
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
};
half2_t maxvalue2 = { 0, 0 };
packed_input pack[NUM_STAGES];
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, int M, int N, int K, float *act, const packed_fpsum_t *wgt) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
const int idx = k * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
if (idx < K) {
pack[k] = load(reinterpret_cast<const packed_input *>(&input[idx]));
} else {
pack[k].fill(make_half2(0, 0));
}
}
if constexpr (rank > 0) {
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
// int dummy = 0;
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
// FIXME: pipeline does not work
// TODO: store quantized data to shmem (instead of half)
// clock_t dummy = 0;
for (int k1 = 0; k1 < ceilDiv(K, PACK_SIZE * WARP_SIZE); k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
const int nextidx = (k1 + k2 + NUM_STAGES - 1) * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
const int nextk2 = (k2 + NUM_STAGES - 1) % NUM_STAGES;
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
if (nextidx < K) {
pack[nextk2] = load(reinterpret_cast<const packed_input *>(&input[nextidx]));
} else {
pack[nextk2].fill(make_half2(0, 0));
}
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
packed_input &p = pack[k2];
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum);
if constexpr (fuse_glu) {
packed_gated_input gated;
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
#pragma unroll
for (int j = 0; j < p.size(); j++) {
gated[j] = p[j].x * gelu_half(p[j].y);
p[j].x = gated[j];
p[j].y = 0;
// if (alwaysfalse) {
// dummy = clock();
// }
}
int idx = (k1 + k2) * PACK_SIZE / 2 * WARP_SIZE + laneId * PACK_SIZE / 2;
if (idx < K) {
store<true>(reinterpret_cast<packed_gated_input *>(&output_shmem[idx]), gated);
}
}
reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);
#pragma unroll
for (int j = 0; j < p.size(); j++) {
maxvalue2 = __hmax2(maxvalue2, __habs2(p[j]));
// unused_var(dummy, alwaysfalse);
}
}
}
// unused_var(dummy, alwaysfalse);
}
#pragma unroll
for (int mask = 32 / 2; mask > 0; mask /= 2) {
maxvalue2 = __hmax2(maxvalue2, __shfl_xor_sync(~0, maxvalue2, mask));
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
return __hmax(maxvalue2.x, maxvalue2.y);
}
if constexpr (rank == 0) {
return;
}
// each thread block quantize WARP_M * K tile (32 * K)
template<bool fuse_glu>
struct quantize_w8a8_act_kernel {
static bool check(int M, int K) {
const int K2 = fuse_glu ? K / 2 : K;
return M % WARP_M == 0 && K2 % WARP_K == 0;
}
static dim3 gridSize(int M, int K) {
return dim3(M / WARP_M);
}
static dim3 blockSize(int M, int K) {
return dim3(NUM_WARPS * 32);
}
static size_t smemSize(int M, int K) {
if constexpr (!fuse_glu) {
return 0;
apply_lora_down(
fpsum, M, N, K,
args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE
);
}
const int K2 = fuse_glu ? K / 2 : K;
return INSN_M * K2 * sizeof(half_t);
}
};
__device__
void operator()(const half *input, packed_act_t *output, packed_ascale_t *oscales, int K, bool alwaysfalse) {
// for quantize kernel
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
template<bool fuse_glu>
struct quantize_w4a4_fuse_lora_kernel {
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
struct Arguments {
const half_t *input;
const packed_wscale_t *smooth_factor;
packed_act_t *output;
packed_ascale_t *oscales;
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
const int numWarps = blockDim.x / WARP_SIZE;
// aligned to BLOCK_M and BLOCK_N
int M, N; // N should be the actual K in the next GEMM (needs /2 if fuse_glu)
// the actual M and N (no need to /2 if fuse_glu)
int actualM, actualN;
};
// for GEMM kernel
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M);
__device__ __forceinline__
void operator()(Arguments args)
{
const BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
const int bm = binfo.bm;
const int bn = binfo.bn;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) half_t oscale_shmem[WARP_M];
__shared__ alignas(128) half_t rscale_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[NUM_WARPS][512];
const int m_offset = bm * BLOCK_M + warpId * WARP_M;
const int n_offset = bn * BLOCK_N * (fuse_glu ? 2 : 1);
extern __shared__ uint8_t shmem[];
const int K2 = fuse_glu ? K / 2 : K;
fpsum_warp fpsum;
// INSN_M * K2
extern __shared__ uint8_t smem[];
half_t *shmem = reinterpret_cast<half_t *>(smem);
load_act_to_fpsum<fuse_glu>()(
args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
fpsum,
shmem + warpId * SHMEM_PER_WARP
// args.smooth_factor ? args.smooth_factor + n_offset : nullptr
);
for (int tileM = 0; tileM < WARP_M_TILES; tileM++) {
CHECK_NAN(fpsum, "fpsum");
// for (int i = 0; i < 16; i++) {
// printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
// bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
// (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
// }
for (int i = warpId; i < INSN_M; i += numWarps) {
const int rowLocal = tileM * INSN_M + i;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
EpilogueLoraDown()(binfo, fpsum, args.M, args.N, 0, typename EpilogueLoraDown::Arguments{
.lora_wgt_down = args.lora_wgt_down,
.lora_act = args.lora_act,
});
half maxv = findmax_warp<fuse_glu>(input + rowGlobal * K, shmem + i * K2, K, alwaysfalse);
oscale_shmem[rowLocal] = maxv / half(127);
rscale_shmem[rowLocal] = half(127) / maxv;
}
__syncthreads();
EpilogueQuantize<false, false>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false>::Arguments{
.qout = args.output,
.oscales = args.oscales,
.shift_value = 0,
.smooth_factor = args.smooth_factor
});
for (int bk = warpId; bk < K2 / WARP_K; bk += numWarps) {
const int rowLocal = tileM * INSN_M;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
const int col = bk * WARP_K;
packed_act_t tmpout;
if constexpr (fuse_glu) {
quantize_w8a8_warp<true>(
shmem + col,
rscale_shmem + rowLocal,
K2,
tmpout,
&tmp_shmem[warpId]
);
} else {
quantize_w8a8_warp<false>(
input + rowGlobal * K + col,
rscale_shmem + rowLocal,
K,
tmpout,
&tmp_shmem[warpId]
);
}
store(&output[(((bm * K2 / WARP_K + bk) * NUM_WARPS + gemmWarpId) * WARP_M_TILES + tileM) * WARP_SIZE + laneId], tmpout);
}
__syncthreads();
}
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
pack_ascales(oscale_shmem, &oscales[(bm * NUM_WARPS + gemmWarpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
}
};
};
struct EpilogueGelu {
struct Arguments { size_t unused; };
__device__ __forceinline__
static gated_fpsum_warp apply_glu(fpsum_warp fpsum) {
gated_fpsum_warp result;
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
for (int k = 0; k < 4; k++) {
half_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst = src.x * gelu_half(src.y);
}
}
}
return result;
}
// static constexpr float SHIFT_VALUE = 0.171875f;
template<typename F>
__device__ __forceinline__
static fpsum_warp apply_act(fpsum_warp fpsum, F func) {
fpsum_warp result;
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
for (int k = 0; k < 4; k++) {
half2_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst.x = func(src.x);
dst.y = func(src.y);
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t &data = fpsum[i * WARP_N_TILES + j].data[k];
data = gelu_half2(data);
// data = __hadd2(data, half2_t(SHIFT_VALUE, SHIFT_VALUE));
}
}
}
}
return result;
}
};
// template<int PoolSize = 128>
struct EpilogueQKVProj {
struct Arguments {
half_t *out;
int actualM, actualN;
static constexpr int unpack_gated_fpsum_shmem_size = INSN_M * (WARP_N / 2 + 8) * sizeof(half_t);
__device__ __forceinline__
static void unpack_gated_fpsum(gated_fpsum_warp fpsum, half_t *output, int stride, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
half_t *pool_out; // [M / PoolSize, N]
const float *rotary_emb; // [M, HEAD_DIM / 2, ROTARY_EMB_NUM_ELEMENTS]
const half_t *rmsnorm_weight_q; // [HEAD_DIM]
const half_t *rmsnorm_weight_k; // [HEAD_DIM]
float epsilon;
};
constexpr int PACK_SIZE = WARP_N / 2 / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>;
static constexpr int HEAD_DIM = 128;
static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
// +8 to prevent bank conflicts
using matrix_t = half_t[INSN_M][WARP_N / 2 + 8];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
static constexpr int PoolSize = 128;
static constexpr int NUM_WARPS_PER_POOL = PoolSize / WARP_M;
static constexpr int NUM_POOLS_PER_BLOCK = BLOCK_M / PoolSize;
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
packed_gated_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 + j * INSN_N / 2;
*reinterpret_cast<half_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half_t *>(&mat[row][col + 4]) = fsum.data[2];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 0]) = fsum.data[1];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[3];
}
__syncwarp();
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2; // 1 for theta, 2 for {sin, cos} pair
for (int row = 0; row < INSN_M; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
}
__syncwarp();
}
}
__device__ __forceinline__
static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
template<typename Epilogue>
__device__ __forceinline__
static void gemm_w8a8_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueParams,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
using pack_t = unpack_fpsum::pack_t;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
using pack_rope_t = std::array<float, PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS>;
constexpr int LANES_PER_HEAD = HEAD_DIM / PACK_SIZE;
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale; // 1
wscale_warp wscale; // 2
psum_warp psum; // 128
pack_t reduce_tmp;
__shared__ alignas(128) pack_t pool[NUM_WARPS];
for (auto &pack : psum) {
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
// load rmsnorm scales
pack_t rms;
if (laneId < LANES_PER_HEAD) {
rms = load(reinterpret_cast<const pack_t *>(&rmsnorm_weight[laneId * PACK_SIZE]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < PACK_SIZE; i++) {
rms[i] = __shfl_sync(~0, rms[i], laneId % LANES_PER_HEAD);
}
}
}
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
load_ascale(ascales, 0, M, ascale, true);
load_wscale(wscales, 0, N, wscale, true);
for (int k = 0; k < NUM_STAGES - 1; k++) {
load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true);
}
int dummy = 0;
const float *rotary_emb_base_addr = &rotary_emb[(warpId * WARP_M) * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS + laneId * PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS];
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
// load_wscale<false>(wscales, wscale[idx], pred);
CHECK_NAN(fpsum, "fpsum");
// __syncthreads();
// if (alwaysfalse) {
// dummy = clock();
// }
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, INT_MAX, INT_MAX, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
// load rope
pack_rope_t rope;
if (laneId < LANES_PER_HEAD) {
// freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
rope = load(reinterpret_cast<const pack_rope_t *>(&rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < rope.size(); i++) {
rope[i] = __shfl_sync(~0, rope[i], laneId % LANES_PER_HEAD);
}
}
// if (alwaysfalse) {
// dummy = clock();
// }
// rmsnorm
float sqrsum = 0.0f;
for (int i = 0; i < PACK_SIZE; i++) {
sqrsum += float(pack[i]) * float(pack[i]);
CHECK_NAN(sqrsum, "sqrsum");
}
#pragma unroll
for (int mask = LANES_PER_HEAD / 2; mask > 0; mask /= 2) {
sqrsum += __shfl_xor_sync(~0, sqrsum, mask);
}
sqrsum /= HEAD_DIM;
float coef = cuda_frsqrt(sqrsum + epsilon);
CHECK_NAN(coef, "coef");
compute(A[k2], W[k2], psum);
for (int i = 0; i < PACK_SIZE; i++) {
pack[i] *= coef * float(rms[i]);
// if (alwaysfalse) {
// dummy = clock();
// }
CHECK_NAN(rms[i], "rms.wgt");
CHECK_NAN(pack[i], "rms.out");
}
// asm volatile ("membar.cta;");
}
}
#if 1
// rope
for (int i = 0; i < PACK_SIZE; i += 2) {
float2 pack2 = half22float2(half2_t(pack[i], pack[i+1]));
CHECK_NAN(freq[i].x, "rope.freq");
CHECK_NAN(freq[i].y, "rope.freq");
CHECK_NAN(freq[i+1].x, "rope.freq");
CHECK_NAN(freq[i+1].y, "rope.freq");
unused_var(dummy, alwaysfalse);
// half2_t tmp = __hmul2(freq[i], pack2);
// tmp = __hfma2(freq[i+1], pack2, tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
fpsum_warp fpsum;
apply_scales([&](int i, int j) {
return psum[i * WARP_N_TILES + j];
}, ascale, wscale, fpsum);
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// blockIdx.x, blockIdx.y, warpId, rowId,
// blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
// (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
// );
// __trap();
Epilogue()(binfo, fpsum, out, M, N, K, epilogueParams);
}
// half2_t tmp = __hmul2(half2_t(pack2.x, pack2.x), freq[i]);
// tmp = __hfma2(half2_t(pack2.y, pack2.y), freq[i+1], tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
float sin, cos;
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template<typename Epilogue>
struct gemm_w8a8_kernel {
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 1) {
sin = cuda_sin(rope[i / 2]);
cos = cuda_cos(rope[i / 2]);
}
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 2) {
sin = rope[i];
cos = rope[i+1];
}
if (swapBlockXY) {
std::swap(binfo.bm, binfo.bn);
std::swap(binfo.numBlocksM, binfo.numBlocksN);
}
// pack[i] = pack2.x * freq[i].x + pack2.y * freq[i].y;
// pack[i+1] = pack2.x * freq[i+1].x + pack2.y * freq[i+1].y;
const int bm = binfo.bm;
const int bn = binfo.bn;
pack[i] = half_t(pack2.x * cos - pack2.y * sin);
pack[i+1] = half_t(pack2.x * sin + pack2.y * cos);
gemm_w8a8_block<Epilogue>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (1) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES, // only 1 group in W8A8
wscales + bn * (1) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
#if 1
out + (bm * BLOCK_M * N) + bn * BLOCK_N,
#else
out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
CHECK_NAN(pack[i], "rope.out");
CHECK_NAN(pack[i+1], "rope.out");
}
#endif
M, N, K,
epilogueArgs,
alwaysfalse
);
}
};
// mean pool
for (int i = 0; i < PACK_SIZE; i++) {
reduce_tmp[i] += pack[i];
}
});
struct EpilogueGLU {
struct Arguments { size_t unused; };
if (!pool_out) {
return;
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
store<true>(&pool[warpId], reduce_tmp);
__syncthreads();
gated_fpsum_warp gated_fpsum = apply_glu(fpsum);
if (warpId < NUM_POOLS_PER_BLOCK) {
const int row = warpId * NUM_WARPS_PER_POOL;
reduce_tmp = load<true>(&pool[row]);
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_gated_fpsum_shmem_size, 128) * 128];
unpack_gated_fpsum(gated_fpsum, out + warpId * WARP_M * N / 2, N / 2, shmem[warpId]);
for (int i = 1; i < NUM_WARPS_PER_POOL; i++) {
pack_t pack = load<true>(&pool[row + i]);
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] += pack[j];
}
}
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] /= PoolSize;
}
store(reinterpret_cast<pack_t *>(pool_out + warpId * N), reduce_tmp);
}
__syncthreads();
}
};
struct EpilogueSilu {
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
fpsum = apply_act(fpsum, [](half_t x) { return silu(x); });
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, shmem[warpId]);
assert(binfo.numBlocksN % 3 == 0);
const bool is_q = bn < binfo.numBlocksN / 3;
const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
assert(args.actualM == M);
assert(args.actualN == N);
if (is_q || is_k) {
apply(
fpsum,
args.out + bm * BLOCK_M * args.actualN + bn * BLOCK_N,
M, N, K,
args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon
);
} else {
EpilogueDefault()(binfo, fpsum, M, N, K, typename EpilogueDefault::Arguments{
.out = args.out,
.actualM = args.actualM,
.actualN = args.actualN,
});
}
}
};
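// A minimal scalar reference for the fused RMSNorm + RoPE epilogue above, written
// as a hedged sketch (not the kernel itself): it assumes one head vector of
// `head_dim` contiguous values, a per-channel RMSNorm weight, and interleaved
// (even, odd) rotary pairs with one cos/sin value per pair.
static void rmsnorm_rope_reference(float *x, const float *weight,
                                   const float *cos_v, const float *sin_v,
                                   int head_dim, float eps) {
    float sqrsum = 0.0f;
    for (int i = 0; i < head_dim; i++) {
        sqrsum += x[i] * x[i];
    }
    // mirrors sqrsum /= HEAD_DIM; coef = cuda_frsqrt(sqrsum + epsilon) above
    const float coef = 1.0f / sqrtf(sqrsum / head_dim + eps);
    for (int i = 0; i < head_dim; i++) {
        x[i] *= coef * weight[i];
    }
    for (int i = 0; i < head_dim; i += 2) {
        const float c = cos_v[i / 2], s = sin_v[i / 2];
        const float x0 = x[i], x1 = x[i + 1];
        x[i]     = x0 * c - x1 * s;   // same rotation as pack[i] / pack[i+1] above
        x[i + 1] = x0 * s + x1 * c;
    }
}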
@@ -2451,26 +1238,6 @@ public:
asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(*reinterpret_cast<uint32_t *>(&x)) : "r"(*reinterpret_cast<uint32_t *>(&x)));
return x;
}
// __device__ __forceinline__
// static uint4 hmma_fp32(uint4 a, uint2 b, uint4 c) {
// asm volatile(
// "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// :
// "=r"(c.x), "=r"(c.y), "=r"(c.z), "=r"(c.w)
// :
// "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
// "r"(b.x), "r"(b.y),
// // "r"(0), "r"(0), "r"(0), "r"(0)
// "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
// );
// return c;
// }
__device__ __forceinline__
@@ -2491,7 +1258,7 @@ public:
// out_vk: [batch_size, num_heads, head_dim + 1, head_dim]
__device__ __forceinline__
static void apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int batch_m) {
static void apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int num_blocks_per_batch) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
@@ -2505,7 +1272,7 @@ public:
assert(binfo.numBlocksN % 3 == 0);
const int num_heads = binfo.numBlocksN / 3 * 2 * (WARP_N / (LITELA_HEAD_DIM * 2));
const int batch_id = binfo.bm * BLOCK_M / batch_m;
const int batch_id = binfo.bm / num_blocks_per_batch;
for (int head_id = 0; head_id < WARP_N / (LITELA_HEAD_DIM * 2); head_id++) {
const int global_head_id = (binfo.bn - binfo.numBlocksN / 3) * (WARP_N / (LITELA_HEAD_DIM * 2)) + head_id;
@@ -2523,7 +1290,7 @@ public:
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + LITELA_HEAD_DIM / 16 + tile_v];
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], make_half2(0, 0)); // relu
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
}
attn_sum = mma_litela(k, v, attn_sum);
}
@@ -2545,11 +1312,22 @@ public:
packed_f32psum_t attn_sum = { 0 };
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = {}; // TODO fill to 0
if (laneId < 4) {
v.data[0] = make_half2(1, 1);
v.data[2] = make_half2(1, 1);
packed_fpsum_t v = {};
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
}
#pragma unroll
for (int i = 0; i < 4; i++) {
v.data[i] = half2_t(1, 1);
}
// if (laneId < 4) {
// v.data[0] = half2_t(1, 1);
// v.data[2] = half2_t(1, 1);
// }
// if (laneId % 4 == 0) {
// v.data[0] = half2_t(1, 0);
// v.data[1] = half2_t(1, 0);
// }
attn_sum = mma_litela(k, v, attn_sum);
}
const int row = LITELA_HEAD_DIM + laneId / 4;
@@ -2580,11 +1358,12 @@ public:
struct Arguments {
half_t *out_q;
float *out_vk;
int batch_m;
int num_blocks_per_batch;
int actualM;
};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
@@ -2593,11 +1372,14 @@ public:
return EpilogueDefault()(
binfo,
fpsum,
args.out_q + (bm * BLOCK_M * N / 3) + bn * BLOCK_N,
M, N / 3, K, EpilogueDefault::Arguments{});
M, N / 3, K, typename EpilogueDefault::Arguments{
.out = args.out_q,
.actualM = args.actualM,
.actualN = N / 3,
});
}
return apply_litela(binfo, fpsum, args.out_vk, args.batch_m);
return apply_litela(binfo, fpsum, args.out_vk, args.num_blocks_per_batch);
}
// each thread block mults BlockSize*HEAD_DIM q and (HEAD_DIM+1)*HEAD_DIM vk, in-place writes back to q
@@ -2606,7 +1388,7 @@ public:
struct vk_mul_q_kernel {
// FIXME FIXME FIXME
__device__
void operator()(half_t *q, const float *vk, float eps) {
void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
const int block_id = blockIdx.x;
const int head_id = blockIdx.y;
const int batch_id = blockIdx.z;
@@ -2615,6 +1397,8 @@ public:
const int num_heads = gridDim.y;
const int block_size = blockDim.x;
bool pred = block_id * block_size + threadIdx.x < num_tokens;
half_t *localq = &q[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
const float *localvk = &vk[(batch_id * num_heads + head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM];
// half_t *localout = &out[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
@@ -2624,7 +1408,9 @@ public:
half_t qblock[LITELA_HEAD_DIM];
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
*reinterpret_cast<packed_q *>(&qblock[i]) = load(reinterpret_cast<const packed_q *>(&localq[i]));
if (pred) {
*reinterpret_cast<packed_q *>(&qblock[i]) = load(reinterpret_cast<const packed_q *>(&localq[i]));
}
}
float outblock[LITELA_HEAD_DIM + 1];
@@ -2646,497 +1432,63 @@ public:
for (int k = 0; k < opack.size(); k++) {
opack[k] = __fdividef(outblock[i + k], outblock[LITELA_HEAD_DIM] + eps);
}
store(reinterpret_cast<packed_q *>(&localq[i]), opack);
if (pred) {
store(reinterpret_cast<packed_q *>(&localq[i]), opack);
}
}
}
};
};
};
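// Hedged host-side reference for what vk_mul_q_kernel computes per token and per
// head (a sketch, not the kernel): q is contracted against every row of the
// (HEAD_DIM + 1) x HEAD_DIM vk matrix produced by apply_litela, and the extra last
// row (built from the all-ones v) supplies the normalizer. HEAD_DIM is assumed to
// equal LITELA_HEAD_DIM (32 here).
static void vk_mul_q_reference(float *q /* [HEAD_DIM], overwritten in place */,
                               const float *vk /* [(HEAD_DIM + 1) * HEAD_DIM] */,
                               float eps) {
    constexpr int HEAD_DIM = 32;   // assumed LITELA_HEAD_DIM
    float out[HEAD_DIM + 1] = {};
    // out[row] = sum_d q[d] * vk[row][d]
    for (int row = 0; row < HEAD_DIM + 1; row++) {
        for (int d = 0; d < HEAD_DIM; d++) {
            out[row] += q[d] * vk[row * HEAD_DIM + d];
        }
    }
    // normalize by the k-sum row, mirroring __fdividef(outblock[i], outblock[HEAD_DIM] + eps)
    for (int d = 0; d < HEAD_DIM; d++) {
        q[d] = out[d] / (out[HEAD_DIM] + eps);
    }
}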
template<typename kernel, typename ...T>
__global__
static void invoke_kernel(T ...args) {
kernel()(args...);
}
template<typename T>
__global__
static void test_sizeof_device() {
printf("sizeof on device = %d\n", (int)sizeof(T));
}
template<typename T>
static void test_sizeof_host() {
printf("sizeof on host = %d\n", (int)sizeof(T));
}
template<typename T>
static void test_sizeof() {
printf("typeid = %s\n", typeid(T).name());
test_sizeof_host<T>();
test_sizeof_device<T><<<1, 1>>>();
checkCUDA(cudaDeviceSynchronize());
}
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
    Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
bool act_unsigned,
std::vector<float> lora_scales // [R / 16]
) {
using GEMM = GEMM_W4A4;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1] * 2;
assert(K == wgt.shape[1] * 2);
// spdlog::info("M={} N={} K={}", M, N, K);
// spdlog::info("act at {}", act.data_ptr());
// spdlog::info("wgt at {}", wgt.data_ptr());
// spdlog::info("ascales at {}", ascales.data_ptr());
// spdlog::info("wscales at {}", wscales.data_ptr());
// spdlog::info("bias at {}", bias.data_ptr());
auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
// test_sizeof<typename Epilogue::Arguments>();
// std::apply([](auto ...args) {
// (test_sizeof<decltype(args)>(), ...);
// }, args);
invoke_kernel<GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// bias.valid() ? bias.data_ptr<GEMM::packed_wscale_t>() : nullptr,
out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
};
auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
if (!bias.valid()) {
return launch.template operator()<NextEpilogue>(nextArgs);
}
assert(bias.numel() == N);
        // append EpilogueNop to work around the mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
GEMM::EpilogueBias::Arguments{
.bias = bias.data_ptr<GEMM::packed_wscale_t>(),
},
nextArgs,
{}
});
};
// auto launch_bias = launch;
auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) {
assert(lora_up.valid() == lora_act_in.valid());
assert(lora_down.valid() == lora_act_out.valid());
if (!lora_up.valid()) {
assert(!lora_down.valid());
return launch_bias.template operator()<GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
}
const int rank_up = lora_up.shape[1];
assert(lora_up.shape[0] == N);
// assert(lora_up.shape[1] == Lora::LORA_RANK);
assert(lora_act_in.shape[0] == M);
assert(lora_act_in.shape[1] == rank_up);
dispatchVal(rank_up, std::integer_sequence<int, 0, 32, 48, 64, 80, 96>(), [&]<int RANK_UP>() {
using LoraUp = GEMM::Lora<RANK_UP>;
using scale_t = typename LoraUp::scale_t;
template<typename Epilogue, bool ACT_UNSIGNED>
struct gemm_w4a4_kernel {
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
// printf("Device sizeof(args) = %d", (int)sizeof(epilogueArgs));
scale_t scales;
if constexpr (scales.size() > 0) {
assert(lora_scales.size() >= scales.size());
for (size_t i = 0; i < scales.size(); i++) {
scales[i] = lora_scales[i];
}
}
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
if (!lora_down.valid()) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<GEMM::packed_fpsum_t>(),
.scales = scales,
},
midArgs,
nextArgs,
{}
});
if (swapBlockXY) {
std::swap(binfo.bm, binfo.bn);
std::swap(binfo.numBlocksM, binfo.numBlocksN);
}
const int rank_down = lora_down.shape[1];
assert(rank_down == rank_up);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank_down);
lora_act_out.zero_();
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<GEMM::packed_fpsum_t>(),
.scales = scales,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<GEMM::packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
},
nextArgs,
{}
});
// });
});
};
if (qout.valid() && oscales.valid()) {
// dispatchBool(qout_unsigned, [&]<bool USE_UNSIGNED>() {
static constexpr float SHIFT_GELU = 0.171875f;
const int bm = binfo.bm;
const int bn = binfo.bn;
constexpr bool USE_UNSIGNED = true;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<GEMM::packed_act_t>(),
.oscales = oscales.data_ptr<GEMM::packed_ascale_t>(),
.shift_value = SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<GEMM::packed_wscale_t>()
};
// bool fusequant = !out;
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<GEMM::EpilogueCombination<GEMM::EpilogueDefault, EpilogueQuantize>, GEMM::EpilogueGelu>({
GEMM::EpilogueDefault::Arguments{},
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, GEMM::EpilogueGelu>(argsQuantize, {});
gemm_w4a4_block<Epilogue, ACT_UNSIGNED>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (K / WARP_K) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,
wscales + bn * (K / WARP_K) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
// bias ? bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES : nullptr,
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// out + (bm * N / BLOCK_N + bn) * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE,
M, N, K,
epilogueArgs,
alwaysfalse
);
}
// });
} else if (rotary_emb.valid()) {
assert(norm_q.valid());
assert(norm_k.valid());
// assert(isTypeMatch<GEMM::half_t>(rotary_emb.scalar_type()));
assert(rotary_emb.scalar_type() == Tensor::FP32);
assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
launch_lora.template operator()<GEMM::EpilogueQKVProj, GEMM::EpilogueNop>(GEMM::EpilogueQKVProj::Arguments{
.pool_out = poolout.valid() ? poolout.data_ptr<GEMM::half_t>() : nullptr,
.rotary_emb = rotary_emb.data_ptr<float>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}, {});
} else if (out.valid()) {
launch_lora.template operator()<GEMM::EpilogueDefault, GEMM::EpilogueNop>({}, {});
} else {
assert(false);
}
}
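// Hedged illustration of the dispatch pattern used above: epilogues compose via
// EpilogueCombination, and the combined Arguments is simply the per-epilogue
// Arguments listed in order (with a trailing EpilogueNop kept for the Windows
// std::tuple layout workaround noted earlier). The chain below is an illustrative
// example, not a path taken by gemm_w4a4:
using ExampleEpilogue = GEMM_W4A4::EpilogueCombination<
    GEMM_W4A4::EpilogueBias,
    GEMM_W4A4::EpilogueGelu,
    GEMM_W4A4::EpilogueDefault,
    GEMM_W4A4::EpilogueNop>;
// launch.template operator()<ExampleEpilogue>({
//     GEMM_W4A4::EpilogueBias::Arguments{ .bias = bias.data_ptr<GEMM_W4A4::packed_wscale_t>() },
//     {},                                       // EpilogueGelu takes no meaningful arguments here
//     GEMM_W4A4::EpilogueDefault::Arguments{},
//     {}                                        // EpilogueNop
// });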
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth) {
using GEMM = GEMM_W4A4;
// using Lora = GEMM::Lora;
int M = input.numel() / input.shape[-1];
int N = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
assert(output.shape[-1] == N / 2);
// assert(oscales.dtype() == Tensor::FP16);
assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
assert(oscales.numel() == M * N / GEMM::WARP_K);
const int rank = lora_down.shape[1];
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank);
lora_act_out.zero_();
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
dispatchVal(rank, std::integer_sequence<int, 0, 32, 48, 64, 80, 96>(), [&]<int RANK>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<GEMM::packed_wscale_t>() : nullptr,
.output = output.data_ptr<GEMM::packed_act_t>(),
.oscales = oscales.data_ptr<GEMM::packed_ascale_t>(),
.lora_wgt_down = lora_down.data_ptr<GEMM::packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
}
);
checkCUDA(cudaGetLastError());
});
}
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
using GEMM = GEMM_W4A4;
int M = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
assert(output.shape[-1] == K / 2);
// assert(oscales.dtype() == Tensor::FP16);
assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
assert(oscales.numel() == M * K / GEMM::WARP_K);
dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
invoke_kernel<GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE>>>(
input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_act_t>(),
oscales.data_ptr<GEMM::packed_ascale_t>(),
K
);
checkCUDA(cudaGetLastError());
}
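// Hedged usage sketch for quantize_w4a4_act (shapes only), assuming
// GEMM_W4A4::WARP_K == 64 as suggested by the "[K / 64, M]" scale comments above,
// and the Tensor::allocate(shape, dtype, device) helper used elsewhere in this repo.
static void quantize_w4a4_act_example(Tensor act /* [M, K], half */) {
    const int M = act.numel() / act.shape[-1];
    const int K = act.shape[-1];
    Tensor packed = Tensor::allocate({M, K / 2}, Tensor::INT8, act.device());   // two int4 values per byte
    Tensor scales = Tensor::allocate({K / 64, M}, act.dtype(), act.device());   // one scale per 64-wide group
    quantize_w4a4_act(act, packed, scales);
}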
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
using GEMM = GEMM_W4A4;
int N = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.ndims() == 2);
assert(output.shape[0] == N);
assert(output.shape[1] == K / 2);
assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
// assert(oscales.dtype() == Tensor::FP16);
assert(oscales.numel() == N * K / GEMM::WARP_K);
dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
invoke_kernel<GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE>>>(
input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_wgt_t>(),
oscales.data_ptr<GEMM::packed_wscale_t>(),
K
);
checkCUDA(cudaGetLastError());
}
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu) {
int M = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == (fuse_glu ? K / 2 : K));
assert(oscales.dtype() == Tensor::FP16);
assert(oscales.numel() == M * 1);
auto launch = [&]<bool FUSE_GLU>() {
using GEMM = GEMM_W8A8;
using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;
assert(kernel::check(M, K));
dim3 grid = kernel::gridSize(M, K);
dim3 block = kernel::blockSize(M, K);
auto func = invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
func<<<grid, block, kernel::smemSize(M, K)>>>(
input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_act_t>(),
oscales.data_ptr<GEMM::packed_ascale_t>(),
K,
false
);
checkCUDA(cudaGetLastError());
};
if (fuse_glu) {
launch.template operator()<true>();
} else {
launch.template operator()<false>();
}
}
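// Hedged helper showing the shape contract checked above: with fuse_glu the
// quantized activation keeps only half of the channels (presumably the kernel
// applies the GLU gating while quantizing), otherwise all K channels. Illustrative
// only, using the same Tensor::allocate helper as the rest of the repo.
static Tensor alloc_w8a8_act_output(Tensor input, bool fuse_glu) {
    const int M = input.numel() / input.shape[-1];
    const int K = input.shape[-1];
    return Tensor::allocate({M, fuse_glu ? K / 2 : K}, Tensor::INT8, input.device());
}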
void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
)
{
using GEMM = GEMM_W8A8;
using Epilogue = GEMM::EpilogueSilu;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1];
assert(K == wgt.shape[1]);
Epilogue::Arguments epilogueArgs;
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
out.data_ptr<GEMM::half_t>(),
M, N, K, epilogueArgs,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
}
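// Hedged note on the swapBlockMN / swapBlockXY pair used by the launchers above:
// the host swaps grid.x and grid.y before launching and the kernel swaps
// binfo.bm / binfo.bn back, so every logical (bm, bn) tile is still produced; only
// the order in which the hardware schedules the blocks changes for tall problems
// (M > 2 * N). A minimal sketch of the host half:
static dim3 make_swapped_grid(int M, int N, int blockM, int blockN, bool &swapBlockMN) {
    dim3 grid(M / blockM, N / blockN);
    swapBlockMN = M > N * 2;
    if (swapBlockMN) {
        std::swap(grid.x, grid.y);   // undone inside the kernel via swapBlockXY
    }
    return grid;
}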
void gemm_w8a8_fuse_litela(
Tensor act, // [B, (M), K]
Tensor wgt, // [N, K]
Tensor out_q, // [B, (M), N / 3]
Tensor out_vk, // [B, num_heads, head_dim + 1, head_dim]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
) {
using GEMM = GEMM_W8A8;
using Epilogue = GEMM::EpilogueLiteLA;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1];
assert(K == wgt.shape[1]);
assert(out_vk.ndims() == 4);
assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
int batch_size = out_vk.shape[0];
int num_heads = out_vk.shape[1];
assert(M % batch_size == 0);
int batch_m = M / batch_size;
Epilogue::Arguments epilogueArgs;
    epilogueArgs.batch_m = batch_m;
epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
epilogueArgs.out_vk = out_vk.data_ptr<float>();
checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
const GEMM::packed_act_t *,
const GEMM::packed_wgt_t *,
const GEMM::packed_ascale_t *,
const GEMM::packed_wscale_t *,
GEMM::half_t *,
int, int, int,
Epilogue::Arguments,
bool,
bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
nullptr,
M, N, K, epilogueArgs,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
out_q.data_ptr<GEMM::half_t>(),
out_vk.data_ptr<float>(),
1e-6f
);
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
#include "gemm_w4a4.cuh"
namespace nunchaku::kernels {
template<typename Config>
class GEMM_W4A4_Launch {
using GEMM = GEMM_W4A4<Config>;
using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96>;
// using LoraRanks = std::integer_sequence<int, 32>;
using packed_act_t = typename GEMM::packed_act_t;
using packed_wgt_t = typename GEMM::packed_wgt_t;
using packed_ascale_t = typename GEMM::packed_ascale_t;
using packed_wscale_t = typename GEMM::packed_wscale_t;
using packed_fpsum_t = typename GEMM::packed_fpsum_t;
using half_t = typename GEMM::half_t;
public:
static void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
        Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
);
static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu);
static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
static void linearattn_vk_mul_q(Tensor q, Tensor vk);
};
}; // namespace nunchaku::kernels
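// Hedged usage note: because GEMM_W4A4_Launch is only declared here and templated
// on a Config, the implementation file that includes gemm_w4a4.cuh presumably
// instantiates it explicitly for each supported configuration, e.g.
// (hypothetical config name, for illustration only):
//
//     template class nunchaku::kernels::GEMM_W4A4_Launch<SomeW4A4Config>;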