Unverified commit 37a27712, authored by Muyang Li and committed by GitHub
Browse files

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter
parents c1d6fc84 760ab022
#include "gemm_w4a4_launch_impl.cuh" #include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels { namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, true>; template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, true>;
}; };
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh" #include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels { namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, false>; template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, false>;
}; };
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh" #include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels { namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>; template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>;
}; };
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh" #include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels { namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>; template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;
}; };
\ No newline at end of file
#include "gemm_w4a4_launch_impl.cuh" #include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels { namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16_FasterI2F, false>; template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16_FasterI2F, false>;
}; };
\ No newline at end of file
...@@ -9,36 +9,35 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4( ...@@ -9,36 +9,35 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
template<> template<>
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
#endif #endif
Tensor act, // packed act [M, K / 2] Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2] Tensor wgt, // packed act [N, K / 2]
Tensor out, // linear [M, N] Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2] Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M] Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N] Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M] Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N] Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R] Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R] Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R] Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R] Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM] Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM] Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2] Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N] Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim] Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3] Tensor out_linearattn, // linear [B, (M), N / 3]
bool act_unsigned, bool act_unsigned,
std::vector<float> lora_scales, // [R / 16] std::vector<float> lora_scales, // [R / 16]
bool fuse_silu, bool fuse_silu,
bool fp4, bool fp4,
float alpha, float alpha,
Tensor wcscales, // packed ws [N] Tensor wcscales, // packed ws [N]
Tensor out_q, // packed attention [B, H, M, D] Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D] Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D] Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens int attn_tokens) {
) {
#ifdef __INTELLISENSE__ #ifdef __INTELLISENSE__
static constexpr bool USE_FP4 = false; static constexpr bool USE_FP4 = false;
#endif #endif
...@@ -89,32 +88,35 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -89,32 +88,35 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
if constexpr (!USE_FP4) { if constexpr (!USE_FP4) {
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() { dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>, auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
const packed_act_t *, const packed_act_t *,
const packed_wgt_t *, const packed_wgt_t *,
const packed_ascale_t *, const packed_ascale_t *,
const packed_wscale_t *, const packed_wscale_t *,
int, int, int, int,
typename Epilogue::Arguments, int,
bool, int,
bool>; typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) { if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
} }
assert(alpha == 1.0f); assert(alpha == 1.0f);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>( func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(), act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(), wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(), ascales.data_ptr<packed_ascale_t>(),
wscales.data_ptr<packed_wscale_t>(), wscales.data_ptr<packed_wscale_t>(),
M, N, K, M,
N,
K,
args, args,
swapBlockMN, swapBlockMN,
false false);
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
}); });
return; return;
...@@ -124,16 +126,18 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -124,16 +126,18 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() { dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
assert(!act_unsigned); assert(!act_unsigned);
auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>, auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
const packed_act_t *, const packed_act_t *,
const packed_wgt_t *, const packed_wgt_t *,
const packed_amscale_t *, const packed_amscale_t *,
const packed_wmscale_t *, const packed_wmscale_t *,
float, float,
int, int, int, int,
typename Epilogue::Arguments, int,
bool, int,
bool>; typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) { if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
...@@ -141,21 +145,22 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -141,21 +145,22 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(ascales.dtype() == Tensor::FP8_E4M3); assert(ascales.dtype() == Tensor::FP8_E4M3);
assert(wscales.dtype() == Tensor::FP8_E4M3); assert(wscales.dtype() == Tensor::FP8_E4M3);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>( func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(), act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(), wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_amscale_t>(), ascales.data_ptr<packed_amscale_t>(),
wscales.data_ptr<packed_wmscale_t>(), wscales.data_ptr<packed_wmscale_t>(),
alpha, alpha,
M, N, K, M,
N,
K,
args, args,
swapBlockMN, swapBlockMN,
false false);
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
}); });
return; return;
} }
...@@ -171,35 +176,37 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -171,35 +176,37 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
dispatchBool(bias.valid(), [&]<bool USE_BIAS>() { dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() { dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>; using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
// append EpilogueNop to work around the mismatched memory layout of std::tuple between device and host code on Windows // append EpilogueNop to work around the mismatched memory layout of std::tuple between device and host code
// on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device ** // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>; using Epilogue =
return launch.template operator()<Epilogue>({ typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
typename EpilogueBias::Arguments{ return launch.template operator()<Epilogue>(
.bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr, {typename EpilogueBias::Arguments{
.scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr, .bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
}, .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
nextArgs, },
{} nextArgs,
}); {}});
}); });
}); });
}; };
// auto launch_bias = launch; // auto launch_bias = launch;
auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) { auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs,
MidEpilogue::Arguments midArgs) {
assert(lora_up.valid() == lora_act_in.valid()); assert(lora_up.valid() == lora_act_in.valid());
assert(lora_down.valid() == lora_act_out.valid()); assert(lora_down.valid() == lora_act_out.valid());
const int rank_up = lora_up.valid() ? lora_up.shape[1] : 0; const int rank_up = lora_up.valid() ? lora_up.shape[1] : 0;
const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0; const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0;
if (rank_up == 0) { if (rank_up == 0) {
assert(rank_down == 0); assert(rank_down == 0);
return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs}); return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>(
{midArgs, nextArgs});
} }
assert(rank_up % 16 == 0); assert(rank_up % 16 == 0);
assert(lora_up.shape[0] == N); assert(lora_up.shape[0] == N);
...@@ -207,7 +214,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -207,7 +214,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(lora_act_in.shape[0] == M); assert(lora_act_in.shape[0] == M);
assert(lora_act_in.shape[1] == rank_up); assert(lora_act_in.shape[1] == rank_up);
using LoraUp = Lora; using LoraUp = Lora;
using scale_t = typename LoraUp::scale_t; using scale_t = typename LoraUp::scale_t;
scale_t scales; scale_t scales;
...@@ -218,19 +225,20 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -218,19 +225,20 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
} }
if (rank_down == 0) { if (rank_down == 0) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>; using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
return launch_bias.template operator()<Epilogue>({ MidEpilogue,
typename LoraUp::EpilogueLoraUp::Arguments{ NextEpilogue,
.lora_act = lora_act_in.data_ptr<float>(), typename GEMM::EpilogueNop>;
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(), return launch_bias.template operator()<Epilogue>({typename LoraUp::EpilogueLoraUp::Arguments{
.rank = rank_up, .lora_act = lora_act_in.data_ptr<float>(),
.scales = scales, .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.alwaysfalse = false, .rank = rank_up,
}, .scales = scales,
midArgs, .alwaysfalse = false,
nextArgs, },
{} midArgs,
}); nextArgs,
{}});
} }
// assert(rank_down == rank_up); // assert(rank_down == rank_up);
...@@ -246,25 +254,27 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -246,25 +254,27 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() { // dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>; using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>; using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
return launch_bias.template operator()<Epilogue>({ MidEpilogue,
typename LoraUp::EpilogueLoraUp::Arguments{ typename LoraDown::EpilogueLoraDown,
.lora_act = lora_act_in.data_ptr<float>(), NextEpilogue,
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(), typename GEMM::EpilogueNop>;
.rank = rank_up, return launch_bias.template operator()<Epilogue>({typename LoraUp::EpilogueLoraUp::Arguments{
.scales = scales, .lora_act = lora_act_in.data_ptr<float>(),
.alwaysfalse = false, .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
}, .rank = rank_up,
midArgs, .scales = scales,
typename LoraDown::EpilogueLoraDown::Arguments{ .alwaysfalse = false,
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(), },
.lora_act = lora_act_out.data_ptr<float>(), midArgs,
.rank = rank_down, typename LoraDown::EpilogueLoraDown::Arguments{
.alwaysfalse = false, .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
}, .lora_act = lora_act_out.data_ptr<float>(),
nextArgs, .rank = rank_down,
{} .alwaysfalse = false,
}); },
nextArgs,
{}});
// }); // });
}; };
...@@ -276,29 +286,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -276,29 +286,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
static constexpr float SHIFT_GELU = 0.171875f; static constexpr float SHIFT_GELU = 0.171875f;
constexpr bool USE_UNSIGNED = !USE_FP4; constexpr bool USE_UNSIGNED = !USE_FP4;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>; using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
auto argsQuantize = typename EpilogueQuantize::Arguments{ auto argsQuantize =
.qout = qout.data_ptr<packed_act_t>(), typename EpilogueQuantize::Arguments{.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(), .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
.shift_value = USE_FP4 ? 0.0f : SHIFT_GELU, .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>() .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()};
};
// TODO: check if gelu is needed // TODO: check if gelu is needed
if (out.valid()) { if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename Epilogues::EpilogueGelu>({ launch_lora.template
typename GEMM::EpilogueDefault::Arguments{ operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>,
.out = out.data_ptr<half_t>(), typename Epilogues::EpilogueGelu>({typename GEMM::EpilogueDefault::Arguments{
.actualM = actualM, .out = out.data_ptr<half_t>(),
.actualN = actualN, .actualM = actualM,
}, .actualN = actualN,
argsQuantize },
}, {}); argsQuantize},
{});
} else { } else {
launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {}); launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {});
} }
} else if (out_linearattn.valid()) { } else if (out_linearattn.valid()) {
assert(out_vk.valid()); assert(out_vk.valid());
...@@ -311,7 +320,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -311,7 +320,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM); assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N); assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
int batch_size = out_vk.shape[0]; int batch_size = out_vk.shape[0];
int num_heads = out_vk.shape[1]; int num_heads = out_vk.shape[1];
assert(isTypeMatch<half_t>(out_linearattn.dtype())); assert(isTypeMatch<half_t>(out_linearattn.dtype()));
assert(out_linearattn.ndims() == 3); assert(out_linearattn.ndims() == 3);
...@@ -326,12 +335,14 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -326,12 +335,14 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
out_vk.zero_(); out_vk.zero_();
launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(typename Epilogue::Arguments{ launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(
.out_q = out_linearattn.data_ptr<half_t>(), typename Epilogue::Arguments{
.out_vk = out_vk.data_ptr<float>(), .out_q = out_linearattn.data_ptr<half_t>(),
.num_blocks_per_batch = num_blocks_per_batch, .out_vk = out_vk.data_ptr<float>(),
.actualM = M, .num_blocks_per_batch = num_blocks_per_batch,
}, {}); .actualM = M,
},
{});
} else if (rotary_emb.valid()) { } else if (rotary_emb.valid()) {
assert(norm_q.valid()); assert(norm_q.valid());
...@@ -342,8 +353,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -342,8 +353,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M); assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM); assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM);
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS); // assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 *
// launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{ // GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS); launch_lora.template operator()<typename
// GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
// .out = out.data_ptr<half_t>(), // .out = out.data_ptr<half_t>(),
// .actualM = actualM, // .actualM = actualM,
// .actualN = actualN, // .actualN = actualN,
...@@ -355,42 +367,48 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -355,42 +367,48 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// }, {}); // }, {});
using EpilogueRope = typename Epilogues::EpilogueRMSNormRope; using EpilogueRope = typename Epilogues::EpilogueRMSNormRope;
auto argsRope = typename Epilogues::EpilogueRMSNormRope::Arguments{ auto argsRope = typename Epilogues::EpilogueRMSNormRope::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(), .rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<half_t>(), .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<half_t>(), .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
.epsilon = 1e-6, .epsilon = 1e-6,
}; };
if (out_q.valid()) { if (out_q.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>, typename GEMM::EpilogueNop>({ launch_lora.template
argsRope, operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>,
typename Epilogues::EpiloguePackQKV::Arguments{ typename GEMM::EpilogueNop>(
.out_q = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(), {argsRope,
.out_k = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(), typename Epilogues::EpiloguePackQKV::Arguments{
.out_v = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(), .out_q = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.actualM = attn_tokens, .out_k = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)), .out_v = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)), .actualM = attn_tokens,
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)), .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() /
} sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
}, {}); .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() /
sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() /
sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
}},
{});
} else { } else {
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>, typename GEMM::EpilogueNop>({ launch_lora
argsRope, .template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>,
typename GEMM::EpilogueDefault::Arguments{ typename GEMM::EpilogueNop>({argsRope,
.out = out.data_ptr<half_t>(), typename GEMM::EpilogueDefault::Arguments{
.actualM = actualM, .out = out.data_ptr<half_t>(),
.actualN = actualN, .actualM = actualM,
} .actualN = actualN,
}, {}); }},
{});
} }
} else if (out.valid()) { } else if (out.valid()) {
using Epilogue = typename GEMM::EpilogueDefault; using Epilogue = typename GEMM::EpilogueDefault;
typename Epilogue::Arguments args{ typename Epilogue::Arguments args{
.out = out.data_ptr<half_t>(), .out = out.data_ptr<half_t>(),
.actualM = actualM, .actualM = actualM,
.actualN = actualN, .actualN = actualN,
}; };
...@@ -410,7 +428,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) ...@@ -410,7 +428,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
using Epilogue = typename Epilogues::EpilogueLiteLA; using Epilogue = typename Epilogues::EpilogueLiteLA;
int batch_size = vk.shape[0]; int batch_size = vk.shape[0];
int num_heads = vk.shape[1]; int num_heads = vk.shape[1];
int num_tokens = q.shape[1]; int num_tokens = q.shape[1];
assert(isTypeMatch<half_t>(q.scalar_type())); assert(isTypeMatch<half_t>(q.scalar_type()));
...@@ -423,17 +441,21 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) ...@@ -423,17 +441,21 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
BLOCK_SIZE = 128; BLOCK_SIZE = 128;
} }
invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>( invoke_kernel<typename Epilogue::vk_mul_q_kernel>
q.data_ptr<half_t>(), <<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
vk.data_ptr<float>(), q.data_ptr<half_t>(), vk.data_ptr<float>(), 1e-6f, num_tokens);
1e-6f,
num_tokens
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
} }
template<typename Config, bool USE_FP4> template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) { void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input,
Tensor output,
Tensor oscales,
Tensor lora_down,
Tensor lora_act_out,
Tensor smooth,
bool fuse_glu,
bool fp4) {
const int actualM = input.numel() / input.shape[-1]; const int actualM = input.numel() / input.shape[-1];
const int actualN = input.shape[-1]; const int actualN = input.shape[-1];
...@@ -475,24 +497,24 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input ...@@ -475,24 +497,24 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE)); checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel())); // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
// input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>( func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{ typename kernel::Arguments{
.input = input.data_ptr<half_t>(), .input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr, .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(), .output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(), .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(), .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(), .lora_act = lora_act_out.data_ptr<float>(),
.lora_rank = rank, .lora_rank = rank,
.M = M, .M = M,
.N = N, .N = N,
.actualM = actualM, .actualM = actualM,
.actualN = actualN, .actualN = actualN,
.alwaysfalse = false, .alwaysfalse = false,
} });
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
}); });
// }); // });
...@@ -501,7 +523,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input ...@@ -501,7 +523,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
template<typename Config, bool USE_FP4> template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) { void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
if constexpr (USE_FP4) { if constexpr (USE_FP4) {
assert(false); // not implemented assert(false); // not implemented
return; return;
} }
...@@ -518,11 +540,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o ...@@ -518,11 +540,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K); dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>( invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
input.data_ptr<half_t>(), input.data_ptr<half_t>(), output.data_ptr<packed_act_t>(), oscales.data_ptr<packed_ascale_t>(), K);
output.data_ptr<packed_act_t>(),
oscales.data_ptr<packed_ascale_t>(),
K
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
} }
...@@ -540,19 +558,15 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o ...@@ -540,19 +558,15 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
assert(output.ndims() == 2); assert(output.ndims() == 2);
assert(output.shape[0] == N); assert(output.shape[0] == N);
assert(output.shape[1] == K / 2); assert(output.shape[1] == K / 2);
assert(isTypeMatch<half_t>(oscales.dtype())); assert(isTypeMatch<half_t>(oscales.dtype()));
// assert(oscales.dtype() == Tensor::FP16); // assert(oscales.dtype() == Tensor::FP16);
assert(oscales.numel() == N * K / GEMM::WARP_K); assert(oscales.numel() == N * K / GEMM::WARP_K);
dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K); dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>( invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
input.data_ptr<half_t>(), input.data_ptr<half_t>(), output.data_ptr<packed_wgt_t>(), oscales.data_ptr<packed_wscale_t>(), K);
output.data_ptr<packed_wgt_t>(),
oscales.data_ptr<packed_wscale_t>(),
K
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
} }
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -11,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k ...@@ -11,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
assert(input.shape.dataExtent == output.shape.dataExtent); assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16); assert(input.scalar_type() == Tensor::FP16);
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>; using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope; using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0); assert(M % GEMM::BLOCK_M == 0);
...@@ -26,21 +26,18 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k ...@@ -26,21 +26,18 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N); dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>( func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{ typename kernel::Arguments{.input = input.data_ptr<GEMM::half_t>(),
.input = input.data_ptr<GEMM::half_t>(), .output = output.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(), .M = M,
.M = M, .N = N,
.N = N, .actualM = M,
.actualM = M, .actualN = N,
.actualN = N, .argsEpilogue = typename Epilogue::Arguments{
.argsEpilogue = typename Epilogue::Arguments{ .rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
.rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(), .rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(), .rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(), .epsilon = 1e-6,
.epsilon = 1e-6, }});
}
}
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
} }
...@@ -52,7 +49,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n ...@@ -52,7 +49,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
Tensor output = Tensor::empty_like(input); Tensor output = Tensor::empty_like(input);
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>; using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV; using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0); assert(M % GEMM::BLOCK_M == 0);
...@@ -68,24 +65,25 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n ...@@ -68,24 +65,25 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>( func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{ typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(), .input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(), .output = output.data_ptr<GEMM::half_t>(),
.M = M, .M = M,
.N = N, .N = N,
.actualM = M, .actualM = M,
.actualN = N, .actualN = N,
.argsEpilogue = typename Epilogue::Arguments{ .argsEpilogue = typename Epilogue::Arguments{
.out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(), .out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(), .out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(), .out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
.actualM = numTokens, .actualM = numTokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)), .strideHead_q =
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)), int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)), .strideHead_k =
} int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
} .strideHead_v =
); int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
}});
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
} }
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -17,24 +17,22 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl ...@@ -17,24 +17,22 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
assert(oscales.numel() == M * 1); assert(oscales.numel() == M * 1);
auto launch = [&]<bool FUSE_GLU>() { auto launch = [&]<bool FUSE_GLU>() {
using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>; using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;
assert(kernel::check(M, K)); assert(kernel::check(M, K));
dim3 grid = kernel::gridSize(M, K); dim3 grid = kernel::gridSize(M, K);
dim3 block = kernel::blockSize(M, K); dim3 block = kernel::blockSize(M, K);
auto func = invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>; auto func =
invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160)); checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
func<<<grid, block, kernel::smemSize(M, K)>>>( func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
input.data_ptr<GEMM::half_t>(), output.data_ptr<GEMM::packed_act_t>(),
output.data_ptr<GEMM::packed_act_t>(), oscales.data_ptr<GEMM::packed_ascale_t>(),
oscales.data_ptr<GEMM::packed_ascale_t>(), K,
K, false);
false
);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
}; };
...@@ -45,14 +43,12 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl ...@@ -45,14 +43,12 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
} }
} }
void gemm_w8a8(Tensor act, // [M, K] void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K] Tensor wgt, // [N, K]
Tensor out, // [M, N] Tensor out, // [M, N]
Tensor ascales, // [1, M] Tensor ascales, // [1, M]
Tensor wscales, // [1, N] Tensor wscales, // [1, N]
Tensor bias Tensor bias) {
)
{
using GEMM = GEMM_W8A8; using GEMM = GEMM_W8A8;
int M = act.numel() / act.shape[-1]; int M = act.numel() / act.shape[-1];
...@@ -78,16 +74,18 @@ void gemm_w8a8(Tensor act, // [M, K] ...@@ -78,16 +74,18 @@ void gemm_w8a8(Tensor act, // [M, K]
std::swap(grid.x, grid.y); std::swap(grid.x, grid.y);
} }
invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>( invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>
act.data_ptr<GEMM::packed_act_t>(), <<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(), wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(), ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(), wscales.data_ptr<GEMM::packed_wscale_t>(),
// out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr, // out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
M, N, K, args, M,
swapBlockMN, N,
false K,
); args,
swapBlockMN,
false);
checkCUDA(cudaGetLastError()); checkCUDA(cudaGetLastError());
}; };
...@@ -98,20 +96,19 @@ void gemm_w8a8(Tensor act, // [M, K] ...@@ -98,20 +96,19 @@ void gemm_w8a8(Tensor act, // [M, K]
assert(bias.numel() == N); assert(bias.numel() == N);
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on
// Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device ** // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>; using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({ return launch.template operator()<Epilogue>({GEMM::EpilogueBias<true, false>::Arguments{
GEMM::EpilogueBias<true, false>::Arguments{ .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
.bias = bias.data_ptr<GEMM::packed_wscale_t>(), },
}, nextArgs,
nextArgs, {}});
{}
});
}; };
launch_bias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{ launch_bias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<GEMM::half_t>(), .out = out.data_ptr<GEMM::half_t>(),
.actualM = actualM, .actualM = actualM,
.actualN = actualN, .actualN = actualN,
}); });
...@@ -152,9 +149,9 @@ void gemm_w8a8_fuse_litela( ...@@ -152,9 +149,9 @@ void gemm_w8a8_fuse_litela(
checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize())); checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>, auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
const GEMM::packed_act_t *, const GEMM::packed_act_t *,
const GEMM::packed_wgt_t *, const GEMM::packed_wgt_t *,
const GEMM::packed_ascale_t *, const GEMM::packed_ascale_t *,
const GEMM::packed_wscale_t *, const GEMM::packed_wscale_t *,
// GEMM::half_t *, // GEMM::half_t *,
...@@ -178,7 +175,7 @@ void gemm_w8a8_fuse_litela( ...@@ -178,7 +175,7 @@ void gemm_w8a8_fuse_litela(
ascales.data_ptr<GEMM::packed_ascale_t>(), ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(), wscales.data_ptr<GEMM::packed_wscale_t>(),
// nullptr, // nullptr,
M, N, K, epilogueArgs, M, N, K, epilogueArgs,
swapBlockMN, swapBlockMN,
false false
); );
...@@ -193,4 +190,4 @@ void gemm_w8a8_fuse_litela( ...@@ -193,4 +190,4 @@ void gemm_w8a8_fuse_litela(
} }
#endif #endif
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -8,48 +8,52 @@ class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> { ...@@ -8,48 +8,52 @@ class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public: public:
using psum_warp = std::array<packed_psum_t, WARP_M_TILES * WARP_N_TILES>; using psum_warp = std::array<packed_psum_t, WARP_M_TILES * WARP_N_TILES>;
__device__ __forceinline__ __device__ __forceinline__ static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
// packed_psum_t psum; // packed_psum_t psum;
asm volatile( asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " "{%0, %1, %2, %3},"
"{%0, %1, %2, %3}," "{%4, %5, %6, %7},"
"{%4, %5, %6, %7}," "{%8, %9},"
"{%8, %9}," "{%10, %11, %12, %13};\n"
"{%10, %11, %12, %13};\n" : "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
: : "r"(act.x),
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3]) "r"(act.y),
: "r"(act.z),
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w), "r"(act.w),
"r"(wgt.x), "r"(wgt.y), "r"(wgt.x),
// "r"(0), "r"(0), "r"(0), "r"(0) "r"(wgt.y),
"r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3]) // "r"(0), "r"(0), "r"(0), "r"(0)
); "r"(psum.data[0]),
asm volatile( "r"(psum.data[1]),
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " "r"(psum.data[2]),
"{%0, %1, %2, %3}," "r"(psum.data[3]));
"{%4, %5, %6, %7}," asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%8, %9}," "{%0, %1, %2, %3},"
"{%10, %11, %12, %13};\n" "{%4, %5, %6, %7},"
: "{%8, %9},"
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7]) "{%10, %11, %12, %13};\n"
: : "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w), : "r"(act.x),
"r"(wgt.z), "r"(wgt.w), "r"(act.y),
// "r"(0), "r"(0), "r"(0), "r"(0) "r"(act.z),
"r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7]) "r"(act.w),
); "r"(wgt.z),
"r"(wgt.w),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[4]),
"r"(psum.data[5]),
"r"(psum.data[6]),
"r"(psum.data[7]));
return psum; return psum;
} }
__device__ __forceinline__ __device__ __forceinline__ static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll #pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) { for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
psum[i * WARP_N_TILES + j] = mma(A[i], W[j], psum[i * WARP_N_TILES + j]); psum[i * WARP_N_TILES + j] = mma(A[i], W[j], psum[i * WARP_N_TILES + j]);
} }
...@@ -62,11 +66,12 @@ public: ...@@ -62,11 +66,12 @@ public:
* oscales is per-warp (in shared memory) * oscales is per-warp (in shared memory)
* output is per-thread (in regs) * output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes) * shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w}) * default to quantize activation, if quantize weight, input should be column-majored and output should be
* transposed ({x, y, z, w} = {x, z, y, w})
*/ */
template<bool input_shmem = false> template<bool input_shmem = false>
__device__ __forceinline__ __device__ __forceinline__ static void
static void quantize_w8a8_warp(const half_t *input, const half_t *oscales, int stride, packed_act_t &output, void *shmem) { quantize_w8a8_warp(const half_t *input, const half_t *oscales, int stride, packed_act_t &output, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
constexpr int QUANTIZE_BITWIDTH = 8; constexpr int QUANTIZE_BITWIDTH = 8;
...@@ -75,28 +80,29 @@ public: ...@@ -75,28 +80,29 @@ public:
// 1 lane = 1 pack // 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp // 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a // a pack is {a0, ..., a7} in figure
// PACK_SIZE * 4 = INSN_K / 2 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a PACK_SIZE * 4 =
constexpr int PACK_SIZE = INSN_K / 8; // = 4 for 8bit // INSN_K / 2
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE; constexpr int PACK_SIZE = INSN_K / 8; // = 4 for 8bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K; constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP; constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>; using packed_input = std::array<half_t, PACK_SIZE>;
packed_input packs[NUM_PACKWARPS]; packed_input packs[NUM_PACKWARPS];
// load // load
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) { for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW; int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE; int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load<input_shmem>(reinterpret_cast<const packed_input *>(input + rowId * stride + colId)); packs[i] = load<input_shmem>(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
} }
// quantize // quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW]; using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem); matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) { for (int i = 0; i < NUM_PACKWARPS; i++) {
const int row = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW; const int row = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
const int col = laneId % NUM_PACKS_PER_ROW; const int col = laneId % NUM_PACKS_PER_ROW;
...@@ -104,7 +110,7 @@ public: ...@@ -104,7 +110,7 @@ public:
float rscale = cuda_frcp(float(oscales[row])); float rscale = cuda_frcp(float(oscales[row]));
uint32_t qpack = 0; uint32_t qpack = 0;
#pragma unroll #pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) { for (int j = 0; j < PACK_SIZE; j += 2) {
// half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1])); // half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
float2 fval = half22float2(half2_t(packs[i][j], packs[i][j + 1])) * float2(rscale, rscale); float2 fval = half22float2(half2_t(packs[i][j], packs[i][j + 1])) * float2(rscale, rscale);
...@@ -113,7 +119,7 @@ public: ...@@ -113,7 +119,7 @@ public:
mat[row][col] = qpack; mat[row][col] = qpack;
} }
__syncwarp(); __syncwarp();
// convert to imma format // convert to imma format
int row = laneId % 16; int row = laneId % 16;
int col = laneId / 16 * 4; int col = laneId / 16 * 4;
...@@ -126,20 +132,20 @@ public: ...@@ -126,20 +132,20 @@ public:
* each warp finds absmax from a row * each warp finds absmax from a row
*/ */
template<bool fuse_glu = false> template<bool fuse_glu = false>
__device__ __forceinline__ __device__ __forceinline__ static half_t
static half_t findmax_warp(const half_t *input, half_t *output_shmem, int K, bool alwaysfalse) { findmax_warp(const half_t *input, half_t *output_shmem, int K, bool alwaysfalse) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
using packed_input = std::array<half2_t, 4>; using packed_input = std::array<half2_t, 4>;
using packed_gated_input = std::array<half_t, 4>; using packed_gated_input = std::array<half_t, 4>;
constexpr int PACK_SIZE = sizeof(packed_input) / sizeof(half_t); constexpr int PACK_SIZE = sizeof(packed_input) / sizeof(half_t);
constexpr int NUM_STAGES = 2; constexpr int NUM_STAGES = 2;
half2_t maxvalue2 = { 0, 0 }; half2_t maxvalue2 = {0, 0};
packed_input pack[NUM_STAGES]; packed_input pack[NUM_STAGES];
#pragma unroll #pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) { for (int k = 0; k < NUM_STAGES - 1; k++) {
const int idx = k * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE; const int idx = k * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
if (idx < K) { if (idx < K) {
...@@ -155,11 +161,11 @@ public: ...@@ -155,11 +161,11 @@ public:
// TODO: store quantized data to shmem (instead of half) // TODO: store quantized data to shmem (instead of half)
for (int k1 = 0; k1 < ceilDiv(K, PACK_SIZE * WARP_SIZE); k1 += NUM_STAGES) { for (int k1 = 0; k1 < ceilDiv(K, PACK_SIZE * WARP_SIZE); k1 += NUM_STAGES) {
#pragma unroll #pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) { for (int k2 = 0; k2 < NUM_STAGES; k2++) {
const int nextidx = (k1 + k2 + NUM_STAGES - 1) * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE; const int nextidx = (k1 + k2 + NUM_STAGES - 1) * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
const int nextk2 = (k2 + NUM_STAGES - 1) % NUM_STAGES; const int nextk2 = (k2 + NUM_STAGES - 1) % NUM_STAGES;
if (nextidx < K) { if (nextidx < K) {
pack[nextk2] = load(reinterpret_cast<const packed_input *>(&input[nextidx])); pack[nextk2] = load(reinterpret_cast<const packed_input *>(&input[nextidx]));
...@@ -172,11 +178,11 @@ public: ...@@ -172,11 +178,11 @@ public:
if constexpr (fuse_glu) { if constexpr (fuse_glu) {
packed_gated_input gated; packed_gated_input gated;
#pragma unroll #pragma unroll
for (int j = 0; j < p.size(); j++) { for (int j = 0; j < p.size(); j++) {
gated[j] = p[j].x * gelu_half(p[j].y); gated[j] = p[j].x * gelu_half(p[j].y);
p[j].x = gated[j]; p[j].x = gated[j];
p[j].y = 0; p[j].y = 0;
} }
int idx = (k1 + k2) * PACK_SIZE / 2 * WARP_SIZE + laneId * PACK_SIZE / 2; int idx = (k1 + k2) * PACK_SIZE / 2 * WARP_SIZE + laneId * PACK_SIZE / 2;
...@@ -185,7 +191,7 @@ public: ...@@ -185,7 +191,7 @@ public:
} }
} }
#pragma unroll #pragma unroll
for (int j = 0; j < p.size(); j++) { for (int j = 0; j < p.size(); j++) {
maxvalue2 = __hmax2(maxvalue2, __habs2(p[j])); maxvalue2 = __hmax2(maxvalue2, __habs2(p[j]));
} }
...@@ -194,7 +200,7 @@ public: ...@@ -194,7 +200,7 @@ public:
// unused_var(dummy, alwaysfalse); // unused_var(dummy, alwaysfalse);
#pragma unroll #pragma unroll
for (int mask = 32 / 2; mask > 0; mask /= 2) { for (int mask = 32 / 2; mask > 0; mask /= 2) {
maxvalue2 = __hmax2(maxvalue2, __shfl_xor_sync(~0, maxvalue2, mask)); maxvalue2 = __hmax2(maxvalue2, __shfl_xor_sync(~0, maxvalue2, mask));
} }
...@@ -223,8 +229,8 @@ public: ...@@ -223,8 +229,8 @@ public:
return INSN_M * K2 * sizeof(half_t); return INSN_M * K2 * sizeof(half_t);
} }
__device__ __device__ void
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K, bool alwaysfalse) { operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K, bool alwaysfalse) {
// for quantize kernel // for quantize kernel
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
...@@ -232,10 +238,9 @@ public: ...@@ -232,10 +238,9 @@ public:
const int numWarps = blockDim.x / WARP_SIZE; const int numWarps = blockDim.x / WARP_SIZE;
// for GEMM kernel // for GEMM kernel
const int bm = blockIdx.x / (BLOCK_M / WARP_M); const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M); const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M);
__shared__ alignas(128) half_t oscale_shmem[WARP_M]; __shared__ alignas(128) half_t oscale_shmem[WARP_M];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M]; // __shared__ alignas(128) half_t maxv_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[NUM_WARPS][512]; __shared__ alignas(128) uint8_t tmp_shmem[NUM_WARPS][512];
...@@ -249,7 +254,7 @@ public: ...@@ -249,7 +254,7 @@ public:
for (int tileM = 0; tileM < WARP_M_TILES; tileM++) { for (int tileM = 0; tileM < WARP_M_TILES; tileM++) {
for (int i = warpId; i < INSN_M; i += numWarps) { for (int i = warpId; i < INSN_M; i += numWarps) {
const int rowLocal = tileM * INSN_M + i; const int rowLocal = tileM * INSN_M + i;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal; const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
half_t maxv = findmax_warp<fuse_glu>(input + rowGlobal * K, shmem + i * K2, K, alwaysfalse); half_t maxv = findmax_warp<fuse_glu>(input + rowGlobal * K, shmem + i * K2, K, alwaysfalse);
...@@ -260,76 +265,66 @@ public: ...@@ -260,76 +265,66 @@ public:
__syncthreads(); __syncthreads();
for (int bk = warpId; bk < K2 / WARP_K; bk += numWarps) { for (int bk = warpId; bk < K2 / WARP_K; bk += numWarps) {
const int rowLocal = tileM * INSN_M; const int rowLocal = tileM * INSN_M;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal; const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
const int col = bk * WARP_K; const int col = bk * WARP_K;
packed_act_t tmpout; packed_act_t tmpout;
if constexpr (fuse_glu) { if constexpr (fuse_glu) {
quantize_w8a8_warp<true>( quantize_w8a8_warp<true>(shmem + col, oscale_shmem + rowLocal, K2, tmpout, &tmp_shmem[warpId]);
shmem + col,
oscale_shmem + rowLocal,
K2,
tmpout,
&tmp_shmem[warpId]
);
} else { } else {
quantize_w8a8_warp<false>( quantize_w8a8_warp<false>(
input + rowGlobal * K + col, input + rowGlobal * K + col, oscale_shmem + rowLocal, K, tmpout, &tmp_shmem[warpId]);
oscale_shmem + rowLocal,
K,
tmpout,
&tmp_shmem[warpId]
);
} }
store(&output[(((bm * K2 / WARP_K + bk) * NUM_WARPS + gemmWarpId) * WARP_M_TILES + tileM) * WARP_SIZE + laneId], tmpout); store(&output[(((bm * K2 / WARP_K + bk) * NUM_WARPS + gemmWarpId) * WARP_M_TILES + tileM) *
WARP_SIZE +
laneId],
tmpout);
} }
__syncthreads(); __syncthreads();
} }
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t // [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
pack_ascales(oscale_shmem, &oscales[(bm * NUM_WARPS + gemmWarpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]); pack_ascales(oscale_shmem,
&oscales[(bm * NUM_WARPS + gemmWarpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
} }
}; };
__device__ __forceinline__ static gated_fpsum_warp apply_glu(fpsum_warp fpsum) {
__device__ __forceinline__
static gated_fpsum_warp apply_glu(fpsum_warp fpsum) {
gated_fpsum_warp result; gated_fpsum_warp result;
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) { for (int j = 0; j < WARP_N_TILES; j++) {
for (int k = 0; k < 4; k++) { for (int k = 0; k < 4; k++) {
half_t &dst = result[i * WARP_N_TILES + j].data[k]; half_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k]; half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst = src.x * gelu_half(src.y); dst = src.x * gelu_half(src.y);
} }
} }
} }
return result; return result;
} }
static constexpr int unpack_gated_fpsum_shmem_size = INSN_M * (WARP_N / 2 + 8) * sizeof(half_t); static constexpr int unpack_gated_fpsum_shmem_size = INSN_M * (WARP_N / 2 + 8) * sizeof(half_t);
__device__ __forceinline__ __device__ __forceinline__ static void
static void unpack_gated_fpsum(gated_fpsum_warp fpsum, half_t *output, int stride, void *shmem) { unpack_gated_fpsum(gated_fpsum_warp fpsum, half_t *output, int stride, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
constexpr int PACK_SIZE = WARP_N / 2 / WARP_SIZE; constexpr int PACK_SIZE = WARP_N / 2 / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>; using pack_t = std::array<half_t, PACK_SIZE>;
// +8 to prevent bank conflicts // +8 to prevent bank conflicts
using matrix_t = half_t[INSN_M][WARP_N / 2 + 8]; using matrix_t = half_t[INSN_M][WARP_N / 2 + 8];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem); matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
for (int i = 0; i < WARP_M_TILES; i++) { for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) { for (int j = 0; j < WARP_N_TILES; j++) {
packed_gated_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j]; packed_gated_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4; int row = laneId / 4;
int col = laneId % 4 + j * INSN_N / 2; int col = laneId % 4 + j * INSN_N / 2;
*reinterpret_cast<half_t *>(&mat[row][col + 0]) = fsum.data[0]; *reinterpret_cast<half_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half_t *>(&mat[row][col + 4]) = fsum.data[2]; *reinterpret_cast<half_t *>(&mat[row][col + 4]) = fsum.data[2];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[1]; *reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[1];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[3]; *reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[3];
} }
...@@ -345,28 +340,27 @@ public: ...@@ -345,28 +340,27 @@ public:
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half // out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
template<typename Epilogue> template<typename Epilogue>
__device__ __forceinline__ __device__ __forceinline__ static void gemm_w8a8_block(const BlockInfo binfo,
static void gemm_w8a8_block( const packed_act_t *act,
const BlockInfo binfo, const packed_wgt_t *wgt,
const packed_act_t *act, const packed_ascale_t *ascales,
const packed_wgt_t *wgt, const packed_wscale_t *wscales,
const packed_ascale_t *ascales, // half_t *out,
const packed_wscale_t *wscales, int M,
// half_t *out, int N,
int M, int N, int K, int K,
Epilogue::Arguments epilogeParams, Epilogue::Arguments epilogeParams,
bool alwaysfalse) bool alwaysfalse) {
{
constexpr int NUM_STAGES = 2; constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
act_warp A[NUM_STAGES]; // 8 act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32 wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale; // 1 ascale_warp ascale; // 1
wscale_warp wscale; // 2 wscale_warp wscale; // 2
psum_warp psum; // 128 psum_warp psum; // 128
for (auto &pack : psum) { for (auto &pack : psum) {
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
...@@ -377,7 +371,7 @@ public: ...@@ -377,7 +371,7 @@ public:
// load_wscale<true>(wscales, wscale[0], true); // load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true); // load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true); // load_wscale<false>(wscales, wscale[2], true);
load_ascale(ascales, 0, M, ascale, true); load_ascale(ascales, 0, M, ascale, true);
load_wscale(wscales, 0, N, wscale, true); load_wscale(wscales, 0, N, wscale, true);
...@@ -385,14 +379,14 @@ public: ...@@ -385,14 +379,14 @@ public:
load_act(act, k, K, A[k], true); load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true); load_wgt(wgt, k, K, W[k], true);
} }
int dummy = 0; int dummy = 0;
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) { for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll #pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) { for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1; int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES; int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K; bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred); load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred); load_wgt(wgt, nextk, K, W[idx], pred);
...@@ -421,17 +415,15 @@ public: ...@@ -421,17 +415,15 @@ public:
f32psum_warp f32psum; f32psum_warp f32psum;
#pragma unroll #pragma unroll
for (int i = 0; i < f32psum.size(); i++) { for (int i = 0; i < f32psum.size(); i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < 8; j++) { for (int j = 0; j < 8; j++) {
f32psum[i].data[j] = 0; f32psum[i].data[j] = 0;
} }
} }
apply_scales([&](int i, int j) { apply_scales([&](int i, int j) { return psum[i * WARP_N_TILES + j]; }, ascale, wscale, f32psum);
return psum[i * WARP_N_TILES + j];
}, ascale, wscale, f32psum);
fpsum_warp fpsum = packed_fp32_to_fp16(f32psum); fpsum_warp fpsum = packed_fp32_to_fp16(f32psum);
...@@ -443,27 +435,24 @@ public: ...@@ -443,27 +435,24 @@ public:
Epilogue()(binfo, fpsum, M, N, K, epilogeParams); Epilogue()(binfo, fpsum, M, N, K, epilogeParams);
} }
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N] // out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template<typename Epilogue> template<typename Epilogue>
struct gemm_w8a8_kernel { struct gemm_w8a8_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750; static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__ __device__ void operator()(const packed_act_t *act,
void operator()( const packed_wgt_t *wgt,
const packed_act_t *act, const packed_ascale_t *ascales,
const packed_wgt_t *wgt, const packed_wscale_t *wscales,
const packed_ascale_t *ascales, // half_t *out,
const packed_wscale_t *wscales, int M,
// half_t *out, int N,
int M, int N, int K, int K,
Epilogue::Arguments epilogueArgs, Epilogue::Arguments epilogueArgs,
bool swapBlockXY, bool swapBlockXY,
bool alwaysfalse) bool alwaysfalse) {
{
BlockInfo binfo = { BlockInfo binfo = {
.bm = (int)blockIdx.x, .bm = (int)blockIdx.x,
.bn = (int)blockIdx.y, .bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x, .numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y, .numBlocksN = (int)gridDim.y,
}; };
...@@ -476,25 +465,25 @@ public: ...@@ -476,25 +465,25 @@ public:
const int bm = binfo.bm; const int bm = binfo.bm;
const int bn = binfo.bn; const int bn = binfo.bn;
gemm_w8a8_block<Epilogue>( gemm_w8a8_block<Epilogue>(binfo,
binfo, act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE, wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE, ascales + bm * (1) * NUM_WARPS * ASCALES_NUM_PACKS *
ascales + bm * (1) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES, // only 1 group in W8A8 ASCALES_VALID_LANES, // only 1 group in W8A8
wscales + bn * (1) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES, wscales + bn * (1) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
// #if 1 // #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N, // out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else // #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2, // out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif // #endif
M, N, K, M,
epilogueArgs, N,
alwaysfalse K,
); epilogueArgs,
alwaysfalse);
} }
}; };
#if 0 #if 0
struct EpilogueGLU { struct EpilogueGLU {
struct Arguments { size_t unused; }; struct Arguments { size_t unused; };
...@@ -510,9 +499,6 @@ public: ...@@ -510,9 +499,6 @@ public:
} }
}; };
#endif #endif
}; };
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include "gemm_base.cuh" #include "gemm_base.cuh"
namespace nunchaku::kernels { namespace nunchaku::kernels {
template<typename Config> template<typename Config>
...@@ -21,7 +20,7 @@ public: ...@@ -21,7 +20,7 @@ public:
public: public:
static constexpr int MAX_RANK = 1024; static constexpr int MAX_RANK = 1024;
static constexpr int WARP_R = 16; static constexpr int WARP_R = 16;
// static constexpr int LORA_RANK = rank; // static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16; static constexpr int LORA_M_TILES = WARP_M / 16;
...@@ -30,57 +29,57 @@ public: ...@@ -30,57 +29,57 @@ public:
static_assert(LORA_M_TILES == WARP_M_TILES); static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES); static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R] // lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N] // lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :( // we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>; using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>; using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>; using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, MAX_RANK / 16>; using scale_t = std::array<float, MAX_RANK / 16>;
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t // lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// [N / 16, rank / 16, WARP_SIZE] // [N / 16, rank / 16, WARP_SIZE]
__device__ __forceinline__ __device__ __forceinline__ static void
static void load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) { load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId]; const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
const int stride_ntile = rank / 16 * WARP_SIZE; const int stride_ntile = rank / 16 * WARP_SIZE;
unrolled_loop<LORA_N_TILES>([&]<int n>() { unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() { unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int roffset = r * WARP_SIZE; constexpr int roffset = r * WARP_SIZE;
const int noffset = n * stride_ntile; const int noffset = n * stride_ntile;
result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred); result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
}); });
}); });
} }
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float // lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__ __device__ __forceinline__ static void
static void load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) { load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
const float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId]; const float *ptrlane =
&ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES>([&]<int m>() { unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>{ unrolled_loop<LORA_R_TILES>([&]<int r> {
constexpr int i = m * LORA_R_TILES + r; constexpr int i = m * LORA_R_TILES + r;
unrolled_loop<8>([&]<int j>() { unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE; constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r]; result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
}); });
// CHECK_NAN(tmp, "load_lora_act.tmp"); // CHECK_NAN(tmp, "load_lora_act.tmp");
}); });
}); });
} }
// no vector reduction in sm_89 :( // no vector reduction in sm_89 :(
__device__ __forceinline__ __device__ __forceinline__ static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
...@@ -108,7 +107,6 @@ public: ...@@ -108,7 +107,6 @@ public:
// }); // });
// } // }
struct EpilogueLoraUp { struct EpilogueLoraUp {
struct Arguments { struct Arguments {
const float *lora_act; const float *lora_act;
...@@ -120,19 +118,23 @@ public: ...@@ -120,19 +118,23 @@ public:
bool alwaysfalse; bool alwaysfalse;
}; };
__device__ __forceinline__ __device__ __forceinline__ static void apply_lora_up(fpsum_warp &fpsum,
static void apply_lora_up(fpsum_warp &fpsum, const float *act, const packed_fpsum_t *wgt, const scale_t &scales, int rank, bool alwaysfalse) { const float *act,
const packed_fpsum_t *wgt,
const scale_t &scales,
int rank,
bool alwaysfalse) {
constexpr int NUM_STAGES = 2; constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act[NUM_STAGES]; // 32 lora_act_warp lora_act[NUM_STAGES]; // 32
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64 lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
int dummy = 0; int dummy = 0;
#pragma unroll #pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) { for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0 // we have rank > 0
const bool pred = k == 0 ? true : k < rank / WARP_R; const bool pred = k == 0 ? true : k < rank / WARP_R;
...@@ -140,14 +142,14 @@ public: ...@@ -140,14 +142,14 @@ public:
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred); load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
} }
f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128 f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128
auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE { auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
lora_act16_warp A_fp16; lora_act16_warp A_fp16;
for (int m = 0; m < LORA_M_TILES; m++) { for (int m = 0; m < LORA_M_TILES; m++) {
for (int r = 0; r < LORA_R_TILES; r++) { for (int r = 0; r < LORA_R_TILES; r++) {
packed_f32psum_t pack = A[m * LORA_R_TILES + r]; packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll #pragma unroll
for (int j = 0; j < 8; j++) { for (int j = 0; j < 8; j++) {
pack.data[j] *= scales[rtile * LORA_R_TILES + r]; pack.data[j] *= scales[rtile * LORA_R_TILES + r];
} }
...@@ -159,28 +161,28 @@ public: ...@@ -159,28 +161,28 @@ public:
for (int r = 0; r < LORA_R_TILES; r++) { for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act"); CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt"); CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]); f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(
A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
} }
} }
} }
}; };
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) { for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll #pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) { for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) { if (k1 + k2 >= rank / WARP_R) {
break; break;
} }
int nextk = k1 + k2 + NUM_STAGES - 1; int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES; int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R; bool pred = nextk < rank / WARP_R;
if (alwaysfalse) { if (alwaysfalse) {
act += kernels::bit_cast<int>(lora_act[k2][0].data[0]); act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
} }
if (alwaysfalse) { if (alwaysfalse) {
dummy = clock(); dummy = clock();
} }
...@@ -194,25 +196,24 @@ public: ...@@ -194,25 +196,24 @@ public:
// NVCC does not know rank > 0 :( // NVCC does not know rank > 0 :(
// it will generate a branch instruction to skip the initial load // it will generate a branch instruction to skip the initial load
// the branch splits the basic blocks and prevents the overlap of memory access and computing (packed_fp16_to_fp32) // the branch splits the basic blocks and prevents the overlap of memory access and computing
// add fake dependency of loaded data so NVCC will not skip the load // (packed_fp16_to_fp32) add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll #pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) { for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll #pragma unroll
for (auto &&data : lora_act[k]) { for (auto &&data : lora_act[k]) {
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]); dummy ^= kernels::bit_cast<int>(data.data[i]);
} }
} }
#pragma unroll #pragma unroll
for (auto &&data : lora_wgt[k]) { for (auto &&data : lora_wgt[k]) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]); dummy ^= kernels::bit_cast<int>(data.data[i]);
} }
} }
} }
unused_var(dummy, alwaysfalse); unused_var(dummy, alwaysfalse);
...@@ -220,21 +221,20 @@ public: ...@@ -220,21 +221,20 @@ public:
fpsum = packed_fp32_to_fp16(f32psum); fpsum = packed_fp32_to_fp16(f32psum);
} }
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) { operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm; const int bm = binfo.bm;
const int bn = binfo.bn; const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum"); CHECK_NAN(fpsum, "fpsum");
apply_lora_up( apply_lora_up(fpsum,
fpsum, args.lora_act +
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE, args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.scales, args.scales,
args.rank, args.rank,
args.alwaysfalse args.alwaysfalse);
);
CHECK_NAN(fpsum, "fpsum"); CHECK_NAN(fpsum, "fpsum");
} }
...@@ -250,16 +250,16 @@ public: ...@@ -250,16 +250,16 @@ public:
bool alwaysfalse; bool alwaysfalse;
}; };
__device__ __forceinline__ __device__ __forceinline__ static void
static void apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) { apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2; constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64 lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
#pragma unroll #pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) { for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0 // we have rank > 0
bool pred = k == 0 ? true : k < rank / WARP_R; bool pred = k == 0 ? true : k < rank / WARP_R;
...@@ -270,11 +270,11 @@ public: ...@@ -270,11 +270,11 @@ public:
lora_act_warp lora_act; lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros()); lora_act.fill(packed_f32psum_t::zeros());
#pragma unroll #pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) { for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll #pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) { for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll #pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) { for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r]; auto &psum = lora_act[m * LORA_R_TILES + r];
...@@ -294,14 +294,14 @@ public: ...@@ -294,14 +294,14 @@ public:
int dummy = 0; int dummy = 0;
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) { for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll #pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) { for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) { if (k1 + k2 >= rank / WARP_R) {
break; break;
} }
int nextk = k1 + k2 + NUM_STAGES - 1; int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES; int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R; bool pred = nextk < rank / WARP_R;
if (alwaysfalse) { if (alwaysfalse) {
...@@ -324,38 +324,33 @@ public: ...@@ -324,38 +324,33 @@ public:
} }
} }
#pragma unroll #pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) { for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll #pragma unroll
for (auto &&data : lora_wgt[k]) { for (auto &&data : lora_wgt[k]) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]); dummy ^= kernels::bit_cast<int>(data.data[i]);
} }
} }
} }
unused_var(dummy, alwaysfalse); unused_var(dummy, alwaysfalse);
} }
__device__ __forceinline__ __device__ __forceinline__ void
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) { operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm; const int bm = binfo.bm;
const int bn = binfo.bn; const int bn = binfo.bn;
apply_lora_down( apply_lora_down(fpsum,
fpsum, args.lora_act +
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE, args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.rank, args.rank,
args.alwaysfalse args.alwaysfalse);
);
} }
}; };
}; };
}; // namespace nunchaku::kernels
}; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -7,183 +7,169 @@ ...@@ -7,183 +7,169 @@
namespace nunchaku::kernels { namespace nunchaku::kernels {
namespace mma_helper { namespace mma_helper {
struct f32 { struct f32 {
static constexpr const char value[] = "f32"; static constexpr const char value[] = "f32";
}; };
struct f16 { struct f16 {
static constexpr const char value[] = "f16"; static constexpr const char value[] = "f16";
}; };
struct bf16 { struct bf16 {
static constexpr const char value[] = "bf16"; static constexpr const char value[] = "bf16";
}; };
struct s32 { struct s32 {
static constexpr const char value[] = "s32"; static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}; };
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}; // namespace mma_helper
__device__ __forceinline__ __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d; uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1},"
"{%0, %1}," "{%2, %3, %4, %5},"
"{%2, %3, %4, %5}," "{%6, %7},"
"{%6, %7}," "{%8, %9};\n"
"{%8, %9};\n" : "=r"(d.x), "=r"(d.y)
: : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#else #else
asm volatile( asm volatile("{"
"{" ".reg .b32 tmp0, tmp1;"
".reg .b32 tmp0, tmp1;" "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{tmp0, tmp1},"
"{tmp0, tmp1}," "{%2, %3},"
"{%2, %3}," "{%6},"
"{%6}," "{%8, %9};\n"
"{%8, %9};\n" "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1},"
"{%0, %1}," "{%4, %5},"
"{%4, %5}," "{%7},"
"{%7}," "{tmp0, tmp1};"
"{tmp0, tmp1};" "}\n"
"}\n" : "=r"(d.x), "=r"(d.y)
: : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#endif #endif
return d; return d;
} }
template<bool is_bf16> template<bool is_bf16>
__device__ __forceinline__ __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
uint4 d; uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 " "{%0, %1, %2, %3},"
"{%0, %1, %2, %3}," "{%4, %5, %6, %7},"
"{%4, %5, %6, %7}," "{%8, %9},"
"{%8, %9}," "{%10, %11, %12, %13};\n"
"{%10, %11, %12, %13};\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x),
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) "r"(a.y),
: "r"(a.z),
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(a.w),
"r"(b.x), "r"(b.y), "r"(b.x),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "r"(b.y),
"C"(mma_helper::f16bf16<is_bf16>::value) "r"(c.x),
); "r"(c.y),
"r"(c.z),
"r"(c.w),
"C"(mma_helper::f16bf16<is_bf16>::value));
#else #else
static_assert(!is_bf16); static_assert(!is_bf16);
asm volatile( asm volatile("{"
"{" ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{tmp0, tmp1, tmp2, tmp3},"
"{tmp0, tmp1, tmp2, tmp3}," "{%4, %5},"
"{%4, %5}," "{%8},"
"{%8}," "{%10, %11, %12, %13};\n"
"{%10, %11, %12, %13};\n" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3},"
"{%0, %1, %2, %3}," "{%6, %7},"
"{%6, %7}," "{%9},"
"{%9}," "{tmp0, tmp1, tmp2, tmp3};"
"{tmp0, tmp1, tmp2, tmp3};" "}\n"
"}\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#endif #endif
return d; return d;
} }
template<typename AType, typename BType> template<typename AType, typename BType>
__device__ __forceinline__ __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
uint4 d; uint4 d;
static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32; static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 " "{%0, %1, %2, %3},"
"{%0, %1, %2, %3}," "{%4, %5, %6, %7},"
"{%4, %5, %6, %7}," "{%8, %9},"
"{%8, %9}," "{%10, %11, %12, %13};\n"
"{%10, %11, %12, %13};\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x),
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) "r"(a.y),
: "r"(a.z),
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(a.w),
"r"(b.x), "r"(b.y), "r"(b.x),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "r"(b.y),
"n"(K), "r"(c.x),
"C"(AType::value), "r"(c.y),
"C"(BType::value) "r"(c.z),
); "r"(c.w),
"n"(K),
"C"(AType::value),
"C"(BType::value));
#else #else
asm volatile( asm volatile("{"
"{" ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;" "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 " "{tmp0, tmp1},"
"{tmp0, tmp1}," "{%4},"
"{%4}," "{%8},"
"{%8}," "{%10, %11};\n"
"{%10, %11};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 " "{tmp2, tmp3},"
"{tmp2, tmp3}," "{%5},"
"{%5}," "{%8},"
"{%8}," "{%12, %13};\n"
"{%12, %13};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 " "{%0, %1},"
"{%0, %1}," "{%6},"
"{%6}," "{%9},"
"{%9}," "{tmp0, tmp1};\n"
"{tmp0, tmp1};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 " "{%2, %3},"
"{%2, %3}," "{%7},"
"{%7}," "{%9},"
"{%9}," "{tmp2, tmp3};\n"
"{tmp2, tmp3};\n" "}\n"
"}\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x),
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) "r"(a.y),
: "r"(a.z),
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(a.w),
"r"(b.x), "r"(b.y), "r"(b.x),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "r"(b.y),
"n"(K / 2), "r"(c.x),
"C"(AType::value), "r"(c.y),
"C"(BType::value) "r"(c.z),
); "r"(c.w),
"n"(K / 2),
"C"(AType::value),
"C"(BType::value));
#endif #endif
return d; return d;
} }
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -6,156 +6,118 @@ ...@@ -6,156 +6,118 @@
// cuda 12.4- does not support "C" constraint in inline assembly :( // cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now // use explicit specialization for now
namespace nunchaku::kernels { namespace nunchaku::kernels {
namespace mma_helper { namespace mma_helper {
struct f32 { struct f32 {
static constexpr const char value[] = "f32"; static constexpr const char value[] = "f32";
}; };
struct f16 { struct f16 {
static constexpr const char value[] = "f16"; static constexpr const char value[] = "f16";
}; };
struct bf16 { struct bf16 {
static constexpr const char value[] = "bf16"; static constexpr const char value[] = "bf16";
}; };
struct s32 { struct s32 {
static constexpr const char value[] = "s32"; static constexpr const char value[] = "s32";
}; };
struct s4 { struct s4 {
static constexpr const char value[] = "s4"; static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}; };
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}; // namespace mma_helper
__device__ __forceinline__ __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d; uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1},"
"{%0, %1}," "{%2, %3, %4, %5},"
"{%2, %3, %4, %5}," "{%6, %7},"
"{%6, %7}," "{%8, %9};\n"
"{%8, %9};\n" : "=r"(d.x), "=r"(d.y)
: : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#else #else
asm volatile( asm volatile("{"
"{" ".reg .b32 tmp0, tmp1;"
".reg .b32 tmp0, tmp1;" "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{tmp0, tmp1},"
"{tmp0, tmp1}," "{%2, %3},"
"{%2, %3}," "{%6},"
"{%6}," "{%8, %9};\n"
"{%8, %9};\n" "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1},"
"{%0, %1}," "{%4, %5},"
"{%4, %5}," "{%7},"
"{%7}," "{tmp0, tmp1};"
"{tmp0, tmp1};" "}\n"
"}\n" : "=r"(d.x), "=r"(d.y)
: : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
#endif #endif
return d; return d;
} }
template<bool is_bf16> template<bool is_bf16>
__device__ __forceinline__ __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<> template<>
__device__ __forceinline__ __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
uint4 d; uint4 d;
asm volatile( asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0, %1, %2, %3},"
"{%0, %1, %2, %3}," "{%4, %5, %6, %7},"
"{%4, %5, %6, %7}," "{%8, %9},"
"{%8, %9}," "{%10, %11, %12, %13};\n"
"{%10, %11, %12, %13};\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
return d; return d;
} }
#endif #endif
template<> template<>
__device__ __forceinline__ __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
uint4 d; uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3},"
"{%0, %1, %2, %3}," "{%4, %5, %6, %7},"
"{%4, %5, %6, %7}," "{%8, %9},"
"{%8, %9}," "{%10, %11, %12, %13};\n"
"{%10, %11, %12, %13};\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#else #else
asm volatile( asm volatile("{"
"{" ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{tmp0, tmp1, tmp2, tmp3},"
"{tmp0, tmp1, tmp2, tmp3}," "{%4, %5},"
"{%4, %5}," "{%8},"
"{%8}," "{%10, %11, %12, %13};\n"
"{%10, %11, %12, %13};\n" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3},"
"{%0, %1, %2, %3}," "{%6, %7},"
"{%6, %7}," "{%9},"
"{%9}," "{tmp0, tmp1, tmp2, tmp3};"
"{tmp0, tmp1, tmp2, tmp3};" "}\n"
"}\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
#endif #endif
return d; return d;
} }
template<typename AType, typename BType> template<typename AType, typename BType>
__device__ __forceinline__ __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;
template<> template<>
__device__ __forceinline__ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 d; uint4 d;
static constexpr int K = 64; static constexpr int K = 64;
...@@ -166,54 +128,50 @@ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, ui ...@@ -166,54 +128,50 @@ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%4, %5, %6, %7}," "{%4, %5, %6, %7},"
"{%8, %9}," "{%8, %9},"
"{%10, %11, %12, %13};\n" "{%10, %11, %12, %13};\n"
: : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
);
#else #else
asm volatile( asm volatile("{"
"{" ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;" "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 " "{tmp0, tmp1},"
"{tmp0, tmp1}," "{%4},"
"{%4}," "{%8},"
"{%8}," "{%10, %11};\n"
"{%10, %11};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 " "{tmp2, tmp3},"
"{tmp2, tmp3}," "{%5},"
"{%5}," "{%8},"
"{%8}," "{%12, %13};\n"
"{%12, %13};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 " "{%0, %1},"
"{%0, %1}," "{%6},"
"{%6}," "{%9},"
"{%9}," "{tmp0, tmp1};\n"
"{tmp0, tmp1};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 " "{%2, %3},"
"{%2, %3}," "{%7},"
"{%7}," "{%9},"
"{%9}," "{tmp2, tmp3};\n"
"{tmp2, tmp3};\n" "}\n"
"}\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x),
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) "r"(a.y),
: "r"(a.z),
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(a.w),
"r"(b.x), "r"(b.y), "r"(b.x),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "r"(b.y),
"n"(K / 2) "r"(c.x),
); "r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K / 2));
#endif #endif
return d; return d;
} }
template<> template<>
__device__ __forceinline__ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 d; uint4 d;
static constexpr int K = 64; static constexpr int K = 64;
...@@ -224,50 +182,46 @@ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, ui ...@@ -224,50 +182,46 @@ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%4, %5, %6, %7}," "{%4, %5, %6, %7},"
"{%8, %9}," "{%8, %9},"
"{%10, %11, %12, %13};\n" "{%10, %11, %12, %13};\n"
: : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
);
#else #else
asm volatile( asm volatile("{"
"{" ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;" "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 " "{tmp0, tmp1},"
"{tmp0, tmp1}," "{%4},"
"{%4}," "{%8},"
"{%8}," "{%10, %11};\n"
"{%10, %11};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 " "{tmp2, tmp3},"
"{tmp2, tmp3}," "{%5},"
"{%5}," "{%8},"
"{%8}," "{%12, %13};\n"
"{%12, %13};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 " "{%0, %1},"
"{%0, %1}," "{%6},"
"{%6}," "{%9},"
"{%9}," "{tmp0, tmp1};\n"
"{tmp0, tmp1};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 " "{%2, %3},"
"{%2, %3}," "{%7},"
"{%7}," "{%9},"
"{%9}," "{tmp2, tmp3};\n"
"{tmp2, tmp3};\n" "}\n"
"}\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: : "r"(a.x),
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) "r"(a.y),
: "r"(a.z),
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(a.w),
"r"(b.x), "r"(b.y), "r"(b.x),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "r"(b.y),
"n"(K / 2) "r"(c.x),
); "r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K / 2));
#endif #endif
return d; return d;
} }
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -5,50 +5,55 @@ ...@@ -5,50 +5,55 @@
namespace nunchaku::kernels { namespace nunchaku::kernels {
void gemm_w4a4( void gemm_w4a4(Tensor act, // packed act [M, K / 2]
Tensor act, // packed act [M, K / 2] Tensor wgt, // packed act [N, K / 2]
Tensor wgt, // packed act [N, K / 2] Tensor out, // linear [M, N]
Tensor out, // linear [M, N] Tensor qout, // packed act [M, N / 2]
Tensor qout, // packed act [M, N / 2] Tensor ascales, // packed as [K / 64, M]
Tensor ascales, // packed as [K / 64, M] Tensor wscales, // packed ws [K / 64, N]
Tensor wscales, // packed ws [K / 64, N] Tensor oscales, // packed as [N / 64, M]
Tensor oscales, // packed as [N / 64, M] Tensor poolout, // linear [M / PoolSize, N]
Tensor poolout, // linear [M / PoolSize, N] Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_act_in, // packed lora_act [M, R] Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_up, // packed lora_wgt [N, R] Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R] Tensor lora_act_out, // packed lora_act [M, R]
Tensor lora_act_out, // packed lora_act [M, R] Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_q, // linear [HEAD_DIM] Tensor norm_k, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM] Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2] Tensor bias, // packed ws [N]
Tensor bias, // packed ws [N] Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor smooth_factor, // packed ws [N], for quantization of the next layer Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim] Tensor out_linearattn, // linear [B, (M), N / 3]
Tensor out_linearattn,// linear [B, (M), N / 3] bool act_unsigned,
bool act_unsigned, std::vector<float> lora_scales, // [R / 16]
std::vector<float> lora_scales, // [R / 16] bool fuse_silu,
bool fuse_silu, bool fp4,
bool fp4, float alpha,
float alpha, Tensor wcscales,
Tensor wcscales, Tensor out_q, // packed attention [B, H, M, D]
Tensor out_q, // packed attention [B, H, M, D] Tensor out_k, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D] Tensor out_v, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D] int attn_tokens);
int attn_tokens
);
void linearattn_vk_mul_q(Tensor q, Tensor vk); void linearattn_vk_mul_q(Tensor q, Tensor vk);
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false); void quantize_w4a4_act_fuse_lora(Tensor input,
Tensor output,
Tensor oscales,
Tensor lora_down,
Tensor lora_act_out,
Tensor smooth = {},
bool fuse_glu = false,
bool fp4 = false);
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales); void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales); void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
void gemm_w8a8(Tensor act, // [M, K] void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K] Tensor wgt, // [N, K]
Tensor out, // [M, N] Tensor out, // [M, N]
Tensor ascales, // [1, M] Tensor ascales, // [1, M]
Tensor wscales, // [1, N] Tensor wscales, // [1, N]
Tensor bias // packed ws [N] Tensor bias // packed ws [N]
); );
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu); void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);
...@@ -61,13 +66,11 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl ...@@ -61,13 +66,11 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N] // Tensor wscales // [1, N]
// ); // );
void attention_fp16( void attention_fp16(Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM] Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM] Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM] Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM] float scale);
float scale
);
// EXPERIMENTAL, for sm_75 // EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode); void set_faster_i2f_mode(std::string mode);
...@@ -76,4 +79,4 @@ void set_faster_i2f_mode(std::string mode); ...@@ -76,4 +79,4 @@ void set_faster_i2f_mode(std::string mode);
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb); void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens); void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
#include "layernorm.h" #include "layernorm.h"
#include "kernels/layernorm_kernels.h" #include "kernels/layernorm_kernels.h"
LayerNorm::LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device) : LayerNorm::LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device)
hidden_size(hidden_size), eps(eps) : hidden_size(hidden_size), eps(eps) {
{
if (elementwise_affine) { if (elementwise_affine) {
weight = Tensor::allocate({hidden_size}, dtype, device); weight = Tensor::allocate({hidden_size}, dtype, device);
bias = Tensor::allocate({hidden_size}, dtype, device); bias = Tensor::allocate({hidden_size}, dtype, device);
} }
registerParams registerParams(weight, "weight")(bias, "bias");
(weight, "weight")
(bias, "bias")
;
} }
Tensor LayerNorm::forward(Tensor x) { Tensor LayerNorm::forward(Tensor x) {
...@@ -27,10 +23,23 @@ Tensor RMSNorm::forward(Tensor x) { ...@@ -27,10 +23,23 @@ Tensor RMSNorm::forward(Tensor x) {
return out; return out;
} }
void RMSNormGeneral::forward_with_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) { void RMSNormGeneral::forward_with_act_sum(Tensor x,
rms_norm_general_fuse_sum(quantized_hidden_states_buffer, x, this->weight, quantized_sum_buffer, quantized_scale_buffer, variance_epsilon, use_per_token_quant); Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer) {
rms_norm_general_fuse_sum(quantized_hidden_states_buffer,
x,
this->weight,
quantized_sum_buffer,
quantized_scale_buffer,
variance_epsilon,
use_per_token_quant);
} }
void RMSNormGeneral::forward_wo_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) { void RMSNormGeneral::forward_wo_act_sum(Tensor x,
rms_norm_general(quantized_hidden_states_buffer, x, this->weight, quantized_scale_buffer, variance_epsilon, use_per_token_quant); Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer) {
rms_norm_general(
quantized_hidden_states_buffer, x, this->weight, quantized_scale_buffer, variance_epsilon, use_per_token_quant);
} }
...@@ -20,9 +20,8 @@ private: ...@@ -20,9 +20,8 @@ private:
class RMSNorm : public Module { class RMSNorm : public Module {
public: public:
RMSNorm(int hidden_size, float eps, bool use_quant, Tensor::ScalarType dtype, Device device) : RMSNorm(int hidden_size, float eps, bool use_quant, Tensor::ScalarType dtype, Device device)
use_quant(use_quant), variance_epsilon(eps) : use_quant(use_quant), variance_epsilon(eps) {
{
weight = Tensor::allocate({hidden_size}, dtype, device); weight = Tensor::allocate({hidden_size}, dtype, device);
registerParams(weight, "weight"); registerParams(weight, "weight");
} }
...@@ -36,13 +35,16 @@ public: ...@@ -36,13 +35,16 @@ public:
class RMSNormGeneral { class RMSNormGeneral {
friend class LlamaDecoderLayer; friend class LlamaDecoderLayer;
public: public:
RMSNormGeneral(int hidden_size, bool act_sum, float eps, bool use_per_token_quant, Device device) RMSNormGeneral(int hidden_size, bool act_sum, float eps, bool use_per_token_quant, Device device)
: act_sum(act_sum), use_per_token_quant(use_per_token_quant), variance_epsilon(eps) : act_sum(act_sum), use_per_token_quant(use_per_token_quant), variance_epsilon(eps) {
{
this->weight = Tensor::ones({hidden_size}, Tensor::FP32, device); this->weight = Tensor::ones({hidden_size}, Tensor::FP32, device);
} }
void forward(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) { void forward(Tensor x,
Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer) {
if (act_sum) { if (act_sum) {
forward_with_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer); forward_with_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
} else { } else {
...@@ -51,12 +53,18 @@ public: ...@@ -51,12 +53,18 @@ public:
} }
private: private:
void forward_with_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer); void forward_with_act_sum(Tensor x,
void forward_wo_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer); Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer);
void forward_wo_act_sum(Tensor x,
Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer);
private: private:
const bool act_sum; const bool act_sum;
const bool use_per_token_quant; const bool use_per_token_quant;
const float variance_epsilon; const float variance_epsilon;
Tensor weight; Tensor weight;
}; };
\ No newline at end of file
...@@ -4,103 +4,106 @@ ...@@ -4,103 +4,106 @@
#include "Tensor.h" #include "Tensor.h"
namespace pytorch_compat { namespace pytorch_compat {
inline void TORCH_CHECK(bool cond, const std::string &msg = "") { inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
assert (cond); assert(cond);
} }
template<typename T> template<typename T>
inline void C10_CUDA_CHECK(T ret) { inline void C10_CUDA_CHECK(T ret) {
return checkCUDA(ret); return checkCUDA(ret);
} }
namespace at { namespace at {
using ::Tensor; using ::Tensor;
constexpr auto kFloat32 = Tensor::FP32;
constexpr auto kFloat = Tensor::FP32;
constexpr auto kFloat16 = Tensor::FP16;
constexpr auto kBFloat16 = Tensor::BF16;
constexpr auto kInt32 = Tensor::INT32;
constexpr auto kInt64 = Tensor::INT64;
struct Generator {
Generator() { throw std::runtime_error("Not implemented"); }
std::mutex mutex_;
};
namespace cuda {
using ::getCurrentDeviceProperties;
struct StreamWrapper {
cudaStream_t st;
cudaStream_t stream() const { return st; }
};
inline StreamWrapper getCurrentCUDAStream() {
return StreamWrapper(::getCurrentCUDAStream());
}
struct CUDAGuard {
int dev;
};
namespace detail {
inline Generator getDefaultCUDAGenerator() {
return Generator();
}
}
}
using CUDAGeneratorImpl = Generator;
template<typename T>
std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
throw std::runtime_error("Not implemented");
}
}
namespace torch { constexpr auto kFloat32 = Tensor::FP32;
using at::kFloat32; constexpr auto kFloat = Tensor::FP32;
using at::kFloat; constexpr auto kFloat16 = Tensor::FP16;
using at::kFloat16; constexpr auto kBFloat16 = Tensor::BF16;
using at::kBFloat16; constexpr auto kInt32 = Tensor::INT32;
using at::kInt32; constexpr auto kInt64 = Tensor::INT64;
using at::kInt64;
constexpr Device kCUDA = Device::cuda(); struct Generator {
Generator() {
using IntArrayRef = std::vector<int>; throw std::runtime_error("Not implemented");
using TensorOptions = Tensor::TensorOptions;
inline Tensor empty_like(const Tensor &tensor) {
return Tensor::empty_like(tensor);
}
inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device());
}
inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}
namespace nn {
namespace functional {
using PadFuncOptions = std::vector<int>;
inline Tensor pad(Tensor x, PadFuncOptions options) {
throw std::runtime_error("Not implemented");
}
}
}
namespace indexing {
constexpr int None = 0;
struct Slice {
int a;
int b;
};
}
} }
std::mutex mutex_;
};
namespace cuda {
using ::getCurrentDeviceProperties;
namespace c10 { struct StreamWrapper {
using std::optional; cudaStream_t st;
cudaStream_t stream() const {
return st;
} }
};
inline StreamWrapper getCurrentCUDAStream() {
return StreamWrapper(::getCurrentCUDAStream());
}
struct CUDAGuard {
int dev;
};
namespace detail {
inline Generator getDefaultCUDAGenerator() {
return Generator();
}
} // namespace detail
} // namespace cuda
using CUDAGeneratorImpl = Generator;
template<typename T>
std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
throw std::runtime_error("Not implemented");
}
} // namespace at
namespace torch {
using at::kFloat32;
using at::kFloat;
using at::kFloat16;
using at::kBFloat16;
using at::kInt32;
using at::kInt64;
constexpr Device kCUDA = Device::cuda();
using IntArrayRef = std::vector<int>;
using TensorOptions = Tensor::TensorOptions;
inline Tensor empty_like(const Tensor &tensor) {
return Tensor::empty_like(tensor);
}
inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device());
}
inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}
namespace nn {
namespace functional {
using PadFuncOptions = std::vector<int>;
inline Tensor pad(Tensor x, PadFuncOptions options) {
throw std::runtime_error("Not implemented");
}
} // namespace functional
} // namespace nn
namespace indexing {
constexpr int None = 0;
struct Slice {
int a;
int b;
};
} // namespace indexing
} // namespace torch
namespace c10 {
using std::optional;
} }
} // namespace pytorch_compat
...@@ -35,7 +35,7 @@ To test visual output correctness, you can: ...@@ -35,7 +35,7 @@ To test visual output correctness, you can:
lpips = compute_lpips(dir1, dir2) lpips = compute_lpips(dir1, dir2)
``` ```
Here, `dir1` should point to the directory containing the reference images, and `dir2` should contain the images generated by your method. Here, `dir1` should point to the directory containing the reference images, and `dir2` should contain the images generated by your method.
### Setting the LPIPS Threshold ### Setting the LPIPS Threshold
...@@ -43,4 +43,4 @@ To pass the test, the LPIPS score must be below a predefined threshold—typical ...@@ -43,4 +43,4 @@ To pass the test, the LPIPS score must be below a predefined threshold—typical
## Acknowledgments ## Acknowledgments
This contribution guide is adapted from [SGLang](https://github.com/sgl-project/sglang/tree/main/test). We thank them for the inspiration. This contribution guide is adapted from [SGLang](https://github.com/sgl-project/sglang/tree/main/test). We thank them for the inspiration.
\ No newline at end of file
...@@ -3,7 +3,6 @@ import random ...@@ -3,7 +3,6 @@ import random
import datasets import datasets
import yaml import yaml
from huggingface_hub import snapshot_download
from nunchaku.utils import fetch_or_download from nunchaku.utils import fetch_or_download
......
import pytest import pytest
from nunchaku.utils import get_precision, is_turing from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
...@@ -8,7 +9,7 @@ from .utils import run_test ...@@ -8,7 +9,7 @@ from .utils import run_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips", "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[ [
(0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.144), (0.12, 1024, 1024, 30, None, 1, 0.212 if get_precision() == "int4" else 0.161),
], ],
) )
def test_flux_dev_cache( def test_flux_dev_cache(
......
import pytest import pytest
from nunchaku.utils import get_precision, is_turing from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
...@@ -9,7 +10,7 @@ from .utils import run_test ...@@ -9,7 +10,7 @@ from .utils import run_test
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips", "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[ [
(1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146), (1024, 1024, 50, "flashattn2", False, 0.139 if get_precision() == "int4" else 0.146),
(2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.133), (2048, 512, 25, "nunchaku-fp16", False, 0.168 if get_precision() == "int4" else 0.156),
], ],
) )
def test_flux_dev( def test_flux_dev(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment