Commit e9ad0535 authored by muyangli

[major] support SANA

parent 9eb2cee0
#include "SanaModel.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "kernels/misc_kernels.h"
#include <nvtx3/nvToolsExt.h>
using spdlog::fmt_lib::format;
using namespace nunchaku;
SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_pad(ceilDiv(dim, 128) * 128),
qkv_proj(dim, dim_pad * 3, bias, dtype, device),
out_proj(dim_pad, dim, bias, dtype, device),
pag_to_v(std::nullopt)
{
registerChildren
(qkv_proj, "qkv_proj")
(out_proj, "out_proj")
;
if (pag) {
pag_to_v.emplace(dim, dim_pad, bias, dtype, device);
registerChildren(pag_to_v.value(), "pag_to_v");
}
}
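// Linear attention used by SANA (high-level sketch): the QKV projection and the per-head
// accumulation are fused into a single W4A4 GEMM, linearattn_vk_mul_q then turns the
// accumulators into the attention output, and out_proj maps the result back to `dim`.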
Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
constexpr int HEAD_DIM = 32;
assert(x.ndims() == 3);
const int batch_size = x.shape[0];
const int num_tokens = x.shape[1];
const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
assert(x.shape[2] == dim);
const int num_heads = dim_pad / HEAD_DIM;
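// Pad the token dimension to a multiple of 256 with zero rows (likely a tile-size requirement
// of the fused W4A4 kernel); the padding is stripped again after the attention step.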
if (num_tokens_pad != num_tokens) {
spdlog::debug("SanaLinearAttention: pad num_tokens from {} to {}", num_tokens, num_tokens_pad);
Tensor x_pad = Tensor::allocate({batch_size, num_tokens_pad, dim}, x.dtype(), x.device());
x_pad.zero_();
for (int i = 0; i < batch_size; i++) {
x_pad.slice(0, i, i + 1).slice(1, 0, num_tokens).copy_(x.slice(0, i, i + 1));
}
x = x_pad;
}
auto qact = qkv_proj.quantize(x, false);
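// vk is an FP32 per-head accumulator filled by the fused GEMM below; its [HEAD_DIM + 1, HEAD_DIM]
// shape suggests it stores K^T*V together with one extra row (presumably the K sums used as the
// normalizer).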
Tensor q = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());
kernels::gemm_w4a4(
qact.act,
qkv_proj.qweight,
{},
{},
qact.ascales,
qkv_proj.wscales,
{}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {},
vk, q,
qact.is_unsigned, qkv_proj.lora_scales, false);
debug("vk", vk);
debug("q", q);
kernels::linearattn_vk_mul_q(q, vk);
debug("raw_attn_output", q);
if (num_tokens_pad != num_tokens) {
Tensor q_unpad = Tensor::allocate({batch_size, num_tokens, dim_pad}, q.dtype(), q.device());
for (int i = 0; i < batch_size; i++) {
q_unpad.slice(0, i, i + 1).copy_(q.slice(0, i, i + 1).slice(1, 0, num_tokens));
}
q = q_unpad;
}
// kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
// return out_proj.forward(q);
if (!out.valid()) {
out = Tensor::allocate({batch_size, num_tokens, dim}, q.dtype(), q.device());
}
out_proj.forward(q, out);
return out;
}
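// Perturbed-attention guidance (PAG): the unperturbed part of the batch runs normal linear
// attention, while the perturbed part bypasses attention and only applies the value projection
// (pag_to_v) followed by out_proj. With cfg the perturbed samples are taken as the last third of
// the batch (presumably after the CFG pair), otherwise as the last half.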
Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
const int batch_size = x.shape[0];
const int num_tokens = x.shape[1];
Tensor out = Tensor::allocate({batch_size, num_tokens, dim}, x.dtype(), x.device());
Tensor x_org, x_ptb;
Tensor out_org, out_ptb;
if (cfg) {
assert(batch_size % 3 == 0);
x_org = x.slice(0, 0, batch_size * 2 / 3);
x_ptb = x.slice(0, batch_size * 2 / 3, batch_size);
out_org = out.slice(0, 0, batch_size * 2 / 3);
out_ptb = out.slice(0, batch_size * 2 / 3, batch_size);
} else {
assert(batch_size % 2 == 0);
x_org = x.slice(0, 0, batch_size / 2);
x_ptb = x.slice(0, batch_size / 2, batch_size);
out_org = out.slice(0, 0, batch_size / 2);
out_ptb = out.slice(0, batch_size / 2, batch_size);
}
this->forward(x_org, out_org);
Tensor v_ptb = this->pag_to_v.value().forward(x_ptb);
this->out_proj.forward(v_ptb, out_ptb);
return out;
}
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device) :
num_heads(num_heads), head_dim(head_dim),
q_linear(num_heads * head_dim, num_heads * head_dim, true, dtype, device),
kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
out_proj(num_heads * head_dim, num_heads * head_dim, true, dtype, device)
{
registerChildren
(q_linear, "q_linear")
(kv_linear, "kv_linear")
(out_proj, "out_proj")
;
}
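// Cross attention between image tokens (x) and text tokens (cond). The text tokens of all batch
// entries are packed along one dimension, so the varlen flash-attention path is used with
// cu_seqlens_img / cu_seqlens_txt marking the per-sample boundaries.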
Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) {
assert(x.ndims() == 3);
assert(cond.ndims() == 2);
assert(cu_seqlens_img.ndims() == 1);
assert(cu_seqlens_txt.ndims() == 1);
const int batch_size = x.shape[0];
const int num_tokens_img = x.shape[1];
const int num_tokens_txt = cond.shape[0];
assert(cu_seqlens_img.shape[0] == batch_size + 1);
assert(cu_seqlens_txt.shape[0] == batch_size + 1);
Tensor q = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim});
Tensor kv = kv_linear.forward(cond).view({num_tokens_txt, num_heads * 2, head_dim});
Tensor k = kv.slice(1, 0, num_heads);
Tensor v = kv.slice(1, num_heads, num_heads * 2);
Tensor attn_output = mha_varlen_fwd(
q, k, v,
cu_seqlens_img, cu_seqlens_txt,
num_tokens_img, num_tokens_txt,
0.0f,
pow(q.shape[-1], (-0.5)),
false, false,
-1, -1,
false
).front().view({batch_size, num_tokens_img, num_heads * head_dim});
// Tensor attn_output = mha_fwd(q, k, v,
// 0.0f,
// pow(q.shape[-1], (-0.5)),
// false, -1, -1, false
// ).front().view({B, N, num_heads * head_dim});
return out_proj.forward(attn_output);
}
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device) :
in_features(in_features), hidden_features(hidden_features),
inverted_conv(in_features, hidden_features * 2, true, dtype, device),
depth_conv(hidden_features * 2, true, dtype, device),
point_conv(hidden_features, in_features, false, dtype, device)
{
registerChildren
(inverted_conv, "inverted_conv")
(depth_conv, "depth_conv")
(point_conv, "point_conv")
;
}
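// GLU MBConv block: a 1x1 "inverted" conv expanding to 2*hidden_features with a fused SiLU,
// a 3x3 depthwise conv over the (H, W) token grid, and a 1x1 pointwise conv back to in_features;
// the GLU gating appears to be fused into point_conv's quantization step (the `true` flag).
// When H and W are not given, the tokens are assumed to form a square grid.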
Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
if (H <= 0 || W <= 0) {
H = W = int(round(sqrt(double(x.shape[1])))); // assume a square token grid; round to avoid FP truncation
}
x = inverted_conv.forward_silu(x);
x = x.view({x.shape[0], H, W, x.shape[-1]});
debug("inverted_conv_output", x);
x = depth_conv.forward(x);
debug("depth_conv_output", x);
x = x.view({x.shape[0], H * W, x.shape[-1]});
auto qact = point_conv.quantize(x, true);
return point_conv.forward_quant(qact);
}
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device) :
hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
attn(hidden_size, false, pag, dtype, device),
cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, dtype, device),
ff(hidden_size, intermediate_size, dtype, device),
norm1(hidden_size, 1e-6, false, dtype, device),
norm2(hidden_size, 1e-6, false, dtype, device)
{
this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);
registerChildren
(attn, "attn")
(cross_attn, "cross_attn")
(ff, "ff")
;
registerParams
(this->scale_shift_table, "scale_shift_table")
;
}
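// Block forward pass: the timestep embedding plus the learned scale_shift_table yields six
// per-sample modulation tensors (shift/scale/gate for the attention path and for the MLP path),
// which are applied AdaLN-style around linear attention, cross attention, and the GLU MBConv
// feed-forward, with the residual additions fused into mul_add_batch.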
Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
nvtxRangePushA("SanaLinearTransformerBlock");
nvtxRangePushA("chunk");
// Tensor ones = Tensor::ones({hidden_size}, Tensor::FP16, x.device());
const int batch_size = timestep.shape[0];
timestep = timestep.copy(timestep.device());
timestep = timestep.view({batch_size, 6, hidden_size});
kernels::mul_add_batch(timestep, {}, false, 0, this->scale_shift_table, false);
debug("shifted_timestep", timestep);
std::array<Tensor, 6> chunked;
for (int i = 0; i < 6; i++) {
chunked[i] = timestep.slice(1, i, i + 1);
}
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = chunked;
// auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(timestep);
nvtxRangePop();
{
nvtxRangePushA("LinearAttention");
Tensor residual = hidden_states;
Tensor norm_hidden_states = norm1.forward(hidden_states);
kernels::mul_add_batch(norm_hidden_states, scale_msa, true, 1, shift_msa, true);
debug("norm_hidden_states_la", norm_hidden_states);
Tensor attn_output = pag ? attn.forward_pag(norm_hidden_states, cfg) : attn.forward(norm_hidden_states);
debug("attn_output_la", attn_output);
kernels::mul_add_batch(attn_output, gate_msa, true, 0, residual, true);
hidden_states = attn_output;
nvtxRangePop();
}
{
nvtxRangePushA("CrossAttention");
debug("norm_hidden_states_cross", hidden_states);
Tensor attn_output = cross_attn.forward(hidden_states, encoder_hidden_states, cu_seqlens_img, cu_seqlens_txt);
debug("attn_output_cross", attn_output);
kernels::mul_add_batch(attn_output, {}, false, 0, hidden_states, true);
hidden_states = attn_output;
nvtxRangePop();
}
{
nvtxRangePushA("Feed-forward");
debug("hidden_states_ff", hidden_states);
Tensor norm_hidden_states = norm2.forward(hidden_states);
kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 1, shift_mlp, true);
debug("norm_hidden_states_ff", norm_hidden_states);
Tensor ff_output = ff.forward(norm_hidden_states, H, W);
debug("ff_output", ff_output);
kernels::mul_add_batch(ff_output, gate_mlp, true, 0, hidden_states, true);
hidden_states = ff_output;
nvtxRangePop();
}
nvtxRangePop();
debug("hidden_states_out", hidden_states);
return hidden_states;
}
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) :
config(config)
{
const int inner_dim = config.num_attention_heads * config.attention_head_dim;
for (int i = 0; i < config.num_layers; i++) {
transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>(
inner_dim,
ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64,
config.num_cross_attention_heads,
std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
dtype, device
));
registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
}
}
Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
for (int i = 0; i < config.num_layers; i++) {
auto &&block = transformer_blocks[i];
hidden_states = block->forward(
hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
cfg
);
}
return hidden_states;
}
#pragma once
#include "common.h"
#include "Tensor.h"
#include "Linear.h"
#include "layernorm.h"
class SanaLinearAttention : public Module {
public:
SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, Tensor out = {});
Tensor forward_pag(Tensor x, bool cfg);
public:
const int dim;
const int dim_pad;
private:
GEMM_W4A4 qkv_proj;
GEMM_W4A4 out_proj;
std::optional<GEMM_W4A4> pag_to_v;
};
class MultiHeadCrossAttention : public Module {
public:
MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);
public:
const int num_heads;
const int head_dim;
private:
GEMM_W4A4 q_linear;
GEMM_F16 kv_linear;
GEMM_W4A4 out_proj;
};
class SanaGLUMBConv : public Module {
public:
SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x, int H, int W);
public:
const int in_features;
const int hidden_features;
private:
GEMM_W4A4 inverted_conv;
DWCONV depth_conv;
GEMM_W4A4 point_conv;
};
class SanaLinearTransformerBlock : public Module {
public:
SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
public:
const int hidden_size;
const int num_cross_attention_heads;
private:
Tensor scale_shift_table;
// Tensor ones;
SanaLinearAttention attn;
MultiHeadCrossAttention cross_attn;
SanaGLUMBConv ff;
LayerNorm norm1, norm2;
};
struct SanaConfig {
int num_layers;
int num_attention_heads;
int attention_head_dim;
int num_cross_attention_heads;
double expand_ratio;
std::vector<int> pag_layers;
};
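// Illustrative construction (a sketch only; the concrete values come from the Python side and
// the numbers below are hypothetical):
//
//   SanaConfig cfg;
//   cfg.num_layers = 28;
//   cfg.num_attention_heads = 36;
//   cfg.attention_head_dim = 32;
//   cfg.num_cross_attention_heads = 16;
//   cfg.expand_ratio = 2.5;
//   cfg.pag_layers = {8};
//   Device device;
//   device.type = Device::CUDA;
//   device.idx = 0;
//   SanaModel model(cfg, Tensor::BF16, device);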
class SanaModel : public Module {
public:
SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
public:
const SanaConfig config;
public:
std::vector<std::unique_ptr<SanaLinearTransformerBlock>> transformer_blocks;
};
\ No newline at end of file
......@@ -32,6 +32,10 @@ public:
size_t getSize() { return size; }
Device getDevice() { return device; }
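// Whether the underlying memory is stream-ordered (freed asynchronously); buffers that are not
// async are locked while GPU work may still reference them (see Tensor::lockBuffer).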
virtual bool isAsyncBuffer() {
return false;
}
protected:
template <typename Derived>
std::shared_ptr<Derived> shared_from_base() {
......@@ -90,6 +94,9 @@ public:
}
checkCUDA(cudaFreeAsync(this->ptr, 0));
}
virtual bool isAsyncBuffer() override {
return true;
}
};
class BufferCUDASync : public Buffer {
......@@ -499,16 +506,16 @@ private:
return cudaMemcpyDefault;
}
static bool isAsyncBuffer(Buffer *buffer) {
return dynamic_cast<BufferCUDA *>(buffer);
}
// static bool isAsyncBuffer(Buffer *buffer) {
// return dynamic_cast<BufferCUDA *>(buffer);
// }
static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
public:
// before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU completes
static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
if (!isAsyncBuffer(buffer.get())) {
if (!buffer->isAsyncBuffer()) {
lockedBuffers[stream].insert(buffer);
}
}
......
......@@ -33,7 +33,7 @@ Tensor from_torch(at::Tensor input) {
result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
// Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
return result;
}
......
......@@ -13,6 +13,10 @@ public:
this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
this->device.idx = this->tensor.get_device();
}
virtual bool isAsyncBuffer() override {
// TODO: figure out how torch manages memory
return true;
}
private:
at::Tensor tensor;
};
......
#pragma once
#include "common.h"
#include "Tensor.h"
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/bfloat16.h>
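// Dispatch a templated callable on the CUTLASS half type matching the tensor's scalar type
// (FP16 -> cutlass::half_t, BF16 -> cutlass::bfloat16_t).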
template<typename F>
inline void dispatchF16(Tensor::ScalarType type, F &&func) {
if (type == Tensor::FP16) {
func.template operator()<cutlass::half_t>();
} else if (type == Tensor::BF16) {
func.template operator()<cutlass::bfloat16_t>();
} else {
assert(false);
}
}
\ No newline at end of file
#include "common.h"
#include "Tensor.h"
#include "dispatch_cutlass.h"
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
......@@ -10,6 +12,7 @@
// depthwise_Conv2d operation cutlass_sm80_tensorop_f16_s16x8x16fprop_analytic_f16_256x128_64x3_nhwc_align8
#if 0
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
using FilterShape = cutlass::MatrixShape<3, 3>;
......@@ -194,3 +197,140 @@ Tensor depthwise_conv2d_kernel(Tensor A, Tensor B) {
return D;
}
#endif
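// Depthwise 3x3 convolution in FP16/BF16 via the CUTLASS SM80 direct-convolution kernel.
// Input and weight are NHWC; padding 1, stride 1 and dilation 1 are fixed by the problem size
// below, and the optional bias is broadcast through the epilogue's source operand.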
Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
assert(input.ndims() == 4);
const int N = input.size(0);
const int H = input.size(1);
const int W = input.size(2);
const int C_ = input.size(3);
assert(weight.ndims() == 4);
const int K = weight.size(0);
const int R = weight.size(1);
const int S = weight.size(2);
const int C__ = weight.size(3);
// weight = weight.copy(weight.device());
dispatchF16(weight.dtype(), [&]<typename half_t>() {
using ElementOutput = half_t;
using ElementAccumulator = half_t;
using ElementComputeEpilogue = half_t;
using ElementInputA = half_t;
using ElementInputB = half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
using FilterShape = cutlass::MatrixShape<3, 3>;
using ThreadblockShape = cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, 64, FilterShape::kCount>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, FilterShape::kCount>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm80,
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput, ElementComputeEpilogue>,
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
1,
ThreadBlockOutputShape::kN,
ThreadBlockOutputShape::kH,
ThreadBlockOutputShape::kW>,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kFixedStrideDilation,
cutlass::conv::StrideSupport::kFixed,
cutlass::MatrixShape<1, 1>,
cutlass::MatrixShape<1, 1>>::Kernel;
using DeviceKernel = typename cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C__),
cutlass::Tensor4DCoord(1, 1, 1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::conv::Mode::kCrossCorrelation,
1,
C_ // groups
);
const int P = problem_size.P;
const int Q = problem_size.Q;
if (!out.valid()) {
out = Tensor::allocate({N, P, Q, K}, input.dtype(), input.device());
}
assert(out.ndims() == 4);
assert(out.size(0) == N);
assert(out.size(1) == P);
assert(out.size(2) == Q);
assert(out.size(3) == K);
Tensor tmp_weight = Tensor::empty_like(weight);
cutlass::TensorRef<ElementInputA, LayoutInputA> a_ref(input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(2), input.stride(1), input.stride(0)));
cutlass::TensorRef<ElementInputB, LayoutInputB> b_ref(weight.data_ptr<ElementInputB>(), LayoutInputB(weight.stride(2), weight.stride(1), weight.stride(0)));
cutlass::TensorRef<ElementOutput, LayoutOutput> c_ref(bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0, 0, 0));
cutlass::TensorRef<ElementOutput, LayoutOutput> d_ref(out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(2), out.stride(1), out.stride(0)));
cutlass::TensorRef<ElementOutput, LayoutOutput> tmpw_ref(tmp_weight.data_ptr<ElementOutput>(), LayoutOutput(tmp_weight.stride(2), tmp_weight.stride(1), tmp_weight.stride(0)));
typename DeviceKernel::Arguments arguments{
problem_size,
a_ref,
b_ref,
c_ref,
d_ref,
{ElementOutput(1.0f), ElementOutput(bias.valid() ? 1.0f : 0.0f)},
tmpw_ref,
};
DeviceKernel implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
BufferCUDA workspace(workspace_size);
auto stream = getCurrentCUDAStream();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}
status = implicit_gemm_op.initialize(arguments, workspace.getPtr(), stream);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = implicit_gemm_op(stream);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
});
return out;
}
\ No newline at end of file
......@@ -3,4 +3,6 @@
#include "common.h"
#include "Tensor.h"
Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
\ No newline at end of file
// Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias);
\ No newline at end of file
......@@ -47,7 +47,7 @@ Tensor gemm_batched_fp16(
auto sizeO = cutlass::MatrixCoord(M, N);
if (!out.valid()) {
auto outShape = a.shape;
auto outShape = TensorShape(a.shape.dataExtent);
outShape[-1] = N;
out = Tensor::empty(outShape, Tensor::FP32, a.device());
}
......
#include "gemm_f16.h"
#include "dispatch_cutlass.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
......@@ -13,8 +15,8 @@ using spdlog::fmt_lib::format;
Tensor gemm_f16(Tensor input, // FP16
Tensor weight, // FP16
Tensor out, // FP16
float alpha,
float beta
Tensor bias,
float alpha
) {
auto N = weight.size(0);
auto K = input.size(-1);
......@@ -23,102 +25,111 @@ Tensor gemm_f16(Tensor input, // FP16
spdlog::debug("gemm_f16: M={} K={} N={}", M, K, N);
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using ElementComputeEpilogue = cutlass::bfloat16_t;
using ElementInputA = cutlass::bfloat16_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::bfloat16_t; // <- data type of elements in input matrix B
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// #if CUDA_ARCH >= 800
using Gemm = cutlass::gemm::device::Gemm<
ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
// auto out = bias.to(device).view({1, -1}).repeat({M, 1});
if (!out.valid()) {
auto out_shape = input.shape;
out_shape[-1] = N;
out = Tensor::empty(out_shape, input.scalar_type(), input.device());
}
// FIXME: check contiguous of input if dims >= 3
assert(input.stride(-1) == 1);
// assert(input.is_contiguous());
assert(weight.is_contiguous());
assert(out.dtype() == input.scalar_type());
assert(out.shape[-1] == N);
assert(out.numel() / out.shape[-1] == M);
assert(out.stride(-1) == 1);
// FIXME: check contiguous of output if dims >= 3
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2)));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
out_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{ElementOutput(alpha), ElementOutput(beta)},
1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
// cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
BufferCUDA workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.getPtr());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
dispatchF16(weight.dtype(), [&]<typename half_t>() {
using ElementOutput = half_t;
using ElementAccumulator = float;
using ElementComputeEpilogue = half_t;
using ElementInputA = half_t; // <- data type of elements in input matrix A
using ElementInputB = half_t; // <- data type of elements in input matrix B
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// #if CUDA_ARCH >= 800
using Gemm = cutlass::gemm::device::Gemm<
ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
// auto out = bias.to(device).view({1, -1}).repeat({M, 1});
if (!out.valid()) {
auto out_shape = TensorShape(input.shape.dataExtent);
out_shape[-1] = N;
out = Tensor::empty(out_shape, input.scalar_type(), input.device());
}
// FIXME: check contiguous of input if dims >= 3
assert(input.stride(-1) == 1);
// assert(input.is_contiguous());
assert(weight.is_contiguous());
assert(out.dtype() == input.scalar_type());
assert(out.shape[-1] == N);
assert(out.numel() / out.shape[-1] == M);
assert(out.stride(-1) == 1);
// FIXME: check contiguous of output if dims >= 3
assert(!bias.valid() || (bias.ndims() == 1 && bias.shape[0] == N));
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2)));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> bias_ref(
bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
bias_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{ElementOutput(alpha), ElementOutput(bias.valid() ? 1.0f : 0.0f)},
1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
// cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
BufferCUDA workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.getPtr());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
});
return out;
}
......@@ -7,6 +7,6 @@ Tensor gemm_f16(
Tensor input, // FP16
Tensor weight, // FP16
Tensor out, // FP16
float alpha,
float beta
Tensor bias,
float alpha
);
\ No newline at end of file
......@@ -82,7 +82,7 @@ Tensor gemm_w8a8_fp16(Tensor input, // INT8
// auto out = bias.to(device).view({1, -1}).repeat({M, 1});
if (!out.valid()) {
auto out_shape = input.shape;
auto out_shape = TensorShape(input.shape.dataExtent);
out_shape[-1] = N;
out = Tensor::empty(out_shape, Tensor::FP16, input.device());
}
......
......@@ -2,6 +2,8 @@
#include "misc_kernels.h"
#include "dispatch_utils.h"
namespace nunchaku::kernels {
Tensor add(Tensor a, Tensor b) {
assert(a.shape.dataExtent == b.shape.dataExtent);
assert(a.dtype() == b.dtype());
......@@ -34,11 +36,11 @@ void mul_add(Tensor x, Tensor scale, Tensor bias) {
constexpr int unroll = 8;
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
assert((uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(x.numel() % unroll == 0);
assert(scale.numel() % unroll == 0);
assert(!scale.valid() || scale.numel() % unroll == 0);
assert(bias.numel() % unroll == 0);
int threadsPerBlock = 1024;
......@@ -47,8 +49,60 @@ void mul_add(Tensor x, Tensor scale, Tensor bias) {
auto stream = getCurrentCUDAStream();
dispatch(x.scalar_type(), [&]<typename scalar_t>() {
mul_add_kernel<scalar_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), x.numel(), scale.numel(), bias.numel());
if (scale.valid()) {
mul_add_kernel<scalar_t, unroll, false><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), 0, x.numel(), scale.numel(), bias.numel(), 0, 0, 0);
} else {
mul_add_kernel<scalar_t, unroll, true><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(), 0, x.numel(), 1, bias.numel(), 0, 0, 0);
}
});
}
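// Batched variant of mul_add: computes x = x * (scale + scale_shift) + bias per batch element
// (or x = x + bias when scale is not given). The batch_scale / batch_bias flags select whether
// scale and bias carry a leading batch dimension or are broadcast across the batch.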
void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias) {
const int batch_size = x.shape[0];
assert(!batch_scale || scale.shape[0] == batch_size);
assert(!batch_bias || bias.shape[0] == batch_size);
const int numel = x.numel() / batch_size;
const int numel_scale = scale.valid() ? (scale.numel() / (batch_scale ? batch_size : 1)) : 1;
const int numel_bias = bias.numel() / (batch_bias ? batch_size : 1);
assert(numel % numel_scale == 0);
assert(numel % numel_bias == 0);
assert(!scale.valid() || x.dtype() == scale.dtype());
assert(x.dtype() == bias.dtype());
constexpr int unroll = 8;
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(numel % unroll == 0);
assert(!scale.valid() || numel_scale % unroll == 0);
assert(numel_bias % unroll == 0);
int threadsPerBlock = 1024;
dim3 grid(ceilDiv(numel, threadsPerBlock * unroll), batch_size);
auto stream = getCurrentCUDAStream();
dispatch(x.scalar_type(), [&]<typename scalar_t>() {
if (scale.valid()) {
mul_add_kernel<scalar_t, unroll, false><<<grid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
(scalar_t)scale_shift,
numel, numel_scale, numel_bias,
x.stride(0), batch_scale ? scale.stride(0) : 0, batch_bias ? bias.stride(0) : 0);
} else {
mul_add_kernel<scalar_t, unroll, true><<<grid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(),
(scalar_t)scale_shift,
numel, 1, numel_bias,
x.stride(0), 0, batch_bias ? bias.stride(0) : 0);
}
});
}
......@@ -219,7 +273,7 @@ Tensor topk(Tensor x, int k) {
assert(k <= N);
assert(k <= MAXK);
auto outShape = x.shape;
auto outShape = TensorShape(x.shape.dataExtent);
outShape[-1] = k;
outShape.dataStride.clear();
......@@ -252,4 +306,6 @@ template std::array<Tensor, 2> split_mod<2>(Tensor input);
template std::array<Tensor, 3> split_mod<3>(Tensor input);
template std::array<Tensor, 4> split_mod<4>(Tensor input);
template std::array<Tensor, 5> split_mod<5>(Tensor input);
template std::array<Tensor, 6> split_mod<6>(Tensor input);
\ No newline at end of file
template std::array<Tensor, 6> split_mod<6>(Tensor input);
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -3,8 +3,12 @@
#include "common.h"
#include "Tensor.h"
namespace nunchaku::kernels {
Tensor add(Tensor a, Tensor b);
void mul_add(Tensor x, Tensor scale, Tensor bias);
void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias);
Tensor embedding(Tensor input_id, Tensor lookup);
Tensor argmax_sample(Tensor logits);
void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v);
......@@ -16,4 +20,6 @@ void cast(Tensor input, Tensor output);
Tensor topk(Tensor x, int k);
template<size_t N>
std::array<Tensor, N> split_mod(Tensor input);
\ No newline at end of file
std::array<Tensor, N> split_mod(Tensor input);
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -7,6 +7,8 @@
#include "utils.cuh"
#include "activation_kernels_impl.cuh"
namespace nunchaku::kernels {
template<typename T>
__global__ void add_kernel(T *a, T *b, T *c, size_t length) {
......@@ -21,9 +23,9 @@ struct alignas(sizeof(T) * unroll) Tvec {
T data[unroll];
};
template<typename T, int unroll>
__global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_scale, int mod_bias) {
template<typename T, int unroll, bool no_scale>
__global__ void mul_add_kernel(T *x, T *scale, T *bias, T scale_shift, size_t length, int mod_scale, int mod_bias, int64_t batch_stride_x, int64_t batch_stride_scale, int64_t batch_stride_bias) {
const int batch_id = blockIdx.y;
int thread = threadIdx.x + blockIdx.x * blockDim.x;
int i = thread * unroll;
int i_scale = i % mod_scale;
......@@ -33,15 +35,20 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_s
return;
}
using Tvec = ::Tvec<T, unroll>;
using Tvec = nunchaku::kernels::Tvec<T, unroll>;
Tvec rx = *reinterpret_cast<Tvec *>(&x[i]);
Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale]);
Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias]);
Tvec rx = *reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]);
Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale + batch_stride_scale * batch_id]);
Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias + batch_stride_bias * batch_id]);
#pragma unroll
for (int k = 0; k < unroll; k++) {
T tmp = rx.data[k] * rscale.data[k] + rbias.data[k];
T tmp;
if constexpr (no_scale) {
tmp = rx.data[k] + rbias.data[k];
} else {
tmp = rx.data[k] * (rscale.data[k] + scale_shift) + rbias.data[k];
}
if constexpr (std::is_same_v<T, half>) {
tmp = __hmin(tmp, (half)65504);
tmp = __hmax(tmp, (half)-65504);
......@@ -49,7 +56,7 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_s
rx.data[k] = tmp;
}
*reinterpret_cast<Tvec *>(&x[i]) = rx;
*reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]) = rx;
// #pragma unroll
// for (int k = 0; k < unroll; k++) {
......@@ -127,8 +134,8 @@ __global__ void quant_kernel_static(const T * input, int8_t * output, T scale, s
return;
}
using Tvec = ::Tvec<T, unroll>;
using I8vec = ::Tvec<int8_t, unroll>;
using Tvec = nunchaku::kernels::Tvec<T, unroll>;
using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>;
Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
I8vec routput;
......@@ -149,8 +156,8 @@ __global__ void quant_kernel_static_fuse_gelu(const T * input, int8_t * output,
return;
}
using Tvec = ::Tvec<T, unroll>;
using I8vec = ::Tvec<int8_t, unroll>;
using Tvec = nunchaku::kernels::Tvec<T, unroll>;
using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>;
Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
I8vec routput;
......@@ -168,8 +175,8 @@ template<typename Tin, typename Tout, int unroll>
__global__ void cast_kernel(const Tin *input, Tout *output, size_t length) {
const int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
using Tvec_in = ::Tvec<Tin, unroll>;
using Tvec_out = ::Tvec<Tout, unroll>;
using Tvec_in = nunchaku::kernels::Tvec<Tin, unroll>;
using Tvec_out = nunchaku::kernels::Tvec<Tout, unroll>;
Tvec_in rinput = *reinterpret_cast<const Tvec_in *>(&input[i]);
Tvec_out routput;
......@@ -250,4 +257,6 @@ void topk_kernel(const T *input, int *output, int N, int strideInput, int numRow
for (int i = 0; i < K; i++) {
output[row * K + i] = idx[K - i - 1];
}
}
\ No newline at end of file
}
}; // namespace nunchaku::kernels
\ No newline at end of file
This diff is collapsed.
......@@ -2,7 +2,9 @@
#include <cstdint>
#include "common.h"
#include "utils.cuh"
#include "../utils.cuh"
namespace nunchaku::kernels {
static constexpr int clamp(int val, int min, int max) {
if (val < min)
......@@ -74,25 +76,19 @@ static void store(T *addr, T val) {
*addr = val;
}
template<typename T>
__device__ __forceinline__
float2 half22float2(T val);
template<>
__device__ __forceinline__
float2 half22float2<half2>(half2 val) {
static float2 half22float2(half2 val) {
return __half22float2(val);
}
template<>
__device__ __forceinline__
float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
static float2 half22float2(__nv_bfloat162 val) {
return __bfloat1622float2(val);
}
template<typename T>
__device__ __forceinline__
T float22half2(float2 val);
static T float22half2(float2 val) = delete;
template<>
__device__ __forceinline__
......@@ -108,7 +104,7 @@ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
template<typename T>
__device__ __forceinline__
void unused_var(T &val, bool alwaysfalse) {
static void unused_var(T &val, bool alwaysfalse) {
volatile T *ptr = nullptr;
if (alwaysfalse) {
*ptr = val;
......@@ -218,7 +214,7 @@ static float cuda_sigmoidf (float a)
template<typename T>
__device__ __forceinline__
static T gelu_half2(T x) {
float2 xf = half22float2<T>(x);
float2 xf = half22float2(x);
float2 x3f = xf * xf * xf;
float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));
......@@ -242,6 +238,25 @@ static T silu(const T &x) {
// return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
}
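// Element-wise division for half2 / __nv_bfloat162, computed in float via __fdividef.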
__device__ __forceinline__
static half2 h2div(half2 a, half2 b) {
float2 af = half22float2(a);
float2 bf = half22float2(b);
float2 of;
of.x = __fdividef(af.x, bf.x);
of.y = __fdividef(af.y, bf.y);
return float22half2<half2>(of);
};
__device__ __forceinline__
static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 af = half22float2(a);
float2 bf = half22float2(b);
float2 of;
of.x = __fdividef(af.x, bf.x);
of.y = __fdividef(af.y, bf.y);
return float22half2<__nv_bfloat162>(of);
};
__device__ __forceinline__
static void reduce_add(float *addr, float val) {
asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));
......@@ -254,4 +269,6 @@ static void unrolled_loop(F &&lambda) {
(lambda.template operator()<Is>(), ...);
};
call(std::make_integer_sequence<int, cnt>());
}
\ No newline at end of file
}
}; // namespace nunchaku::kernels
\ No newline at end of file
#include "zgemm.h"
#include "gemm_w4a4_launch.cuh"
namespace nunchaku::kernels {
template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, F &&launch) {
if (dtype == Tensor::FP16) {
launch.template operator()<GEMMConfig_W4A4_FP16>();
} else if (dtype == Tensor::BF16) {
launch.template operator()<GEMMConfig_W4A4_BF16>();
} else {
assert(false);
}
}
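// Fused W4A4 GEMM entry point. For SANA the same kernel optionally produces the linear-attention
// side outputs: out_vk accumulates the per-head [head_dim + 1, head_dim] V/K statistics and
// out_linearattn appears to receive the query third of the fused QKV output, as consumed by
// SanaLinearAttention::forward.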
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
) {
invoke_launch(ascales.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::gemm_w4a4(
act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu
);
});
}
void linearattn_vk_mul_q(Tensor q, Tensor vk) {
invoke_launch(q.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(q, vk);
});
}
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu
);
});
}
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_act(
input, output, oscales
);
});
}
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), [&]<typename Config>() {
GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(
input, output, oscales
);
});
}
};
\ No newline at end of file
#include "gemm_w4a4.cuh"
namespace nunchaku::kernels {
template<typename Config>
class GEMM_W4A4_Launch {
using GEMM = GEMM_W4A4<Config>;
using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96>;
// using LoraRanks = std::integer_sequence<int, 32>;
using packed_act_t = typename GEMM::packed_act_t;
using packed_wgt_t = typename GEMM::packed_wgt_t;
using packed_ascale_t = typename GEMM::packed_ascale_t;
using packed_wscale_t = typename GEMM::packed_wscale_t;
using packed_fpsum_t = typename GEMM::packed_fpsum_t;
using half_t = typename GEMM::half_t;
public:
static void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu
);
static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu);
static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
static void linearattn_vk_mul_q(Tensor q, Tensor vk);
};
}; // namespace nunchaku::kernels
\ No newline at end of file