Commit 54e6d065 authored by muyangli

[major] support NVFP4; upgrade to 0.1

parent c7f41661
@@ -12,6 +12,8 @@ class GEMM_W4A4_Launch {
     using packed_wgt_t = typename GEMM::packed_wgt_t;
     using packed_ascale_t = typename GEMM::packed_ascale_t;
     using packed_wscale_t = typename GEMM::packed_wscale_t;
+    using packed_amscale_t = typename GEMM::packed_amscale_t;
+    using packed_wmscale_t = typename GEMM::packed_wmscale_t;
     using packed_fpsum_t = typename GEMM::packed_fpsum_t;
     using half_t = typename GEMM::half_t;
@@ -38,9 +40,12 @@ public:
         Tensor out_linearattn,// linear [B, (M), N / 3]
         bool act_unsigned,
         std::vector<float> lora_scales, // [R / 16]
-        bool fuse_silu
+        bool fuse_silu,
+        bool fp4,
+        float alpha,
+        Tensor wcscales // packed ws [N]
     );
-    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu);
+    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
     static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
     static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
...
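Background note (not part of the commit): the new packed_amscale_t / packed_wmscale_t types, the float alpha, and the FP8 E4M3 checks later in the diff are consistent with the NVFP4 layout, in which 4-bit E2M1 values carry one FP8 (E4M3) scale per 16-element micro-block plus a single FP32 tensor-level scale. A hedged sketch of the implied dequantization, with illustrative symbols that are not taken from the code:

```latex
% Assumed NVFP4 dequantization; q_i is a 4-bit E2M1 value, s_j is the FP8 E4M3
% scale of its 16-element micro-block, and alpha is the global FP32 scale
% passed to the kernel.
\hat{x}_i = \alpha \cdot s_{\lfloor i / 16 \rfloor} \cdot q_i
```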
@@ -30,7 +30,10 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
     Tensor out_linearattn,// linear [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales, // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales // packed ws [N]
 ) {
     int M = act.numel() / act.shape[-1];
     int N = wgt.shape[0];
@@ -68,58 +71,111 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         std::swap(grid.x, grid.y);
     }
-    dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
+    dispatchBool(fp4, [&]<bool USE_FP4>() {
         // test_sizeof<typename Epilogue::Arguments>();
         // std::apply([](auto ...args) {
         //     (test_sizeof<decltype(args)>(), ...);
         // }, args);
-        using kernel = typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>;
-        auto func = invoke_kernel<kernel,
-            const packed_act_t *,
-            const packed_wgt_t *,
-            const packed_ascale_t *,
-            const packed_wscale_t *,
-            int, int, int,
-            typename Epilogue::Arguments,
-            bool,
-            bool>;
-        if (shmem >= 24 * 1024) {
-            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        }
-        func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
-            act.data_ptr<packed_act_t>(),
-            wgt.data_ptr<packed_wgt_t>(),
-            ascales.data_ptr<packed_ascale_t>(),
-            wscales.data_ptr<packed_wscale_t>(),
-            M, N, K,
-            args,
-            swapBlockMN,
-            false
-        );
-        checkCUDA(cudaGetLastError());
+        // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
+        if constexpr (!USE_FP4) {
+            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
+                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
+                    const packed_act_t *,
+                    const packed_wgt_t *,
+                    const packed_ascale_t *,
+                    const packed_wscale_t *,
+                    int, int, int,
+                    typename Epilogue::Arguments,
+                    bool,
+                    bool>;
+                if (shmem >= 24 * 1024) {
+                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+                }
+                assert(alpha == 1.0f);
+                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
+                    act.data_ptr<packed_act_t>(),
+                    wgt.data_ptr<packed_wgt_t>(),
+                    ascales.data_ptr<packed_ascale_t>(),
+                    wscales.data_ptr<packed_wscale_t>(),
+                    M, N, K,
+                    args,
+                    swapBlockMN,
+                    false
+                );
+                checkCUDA(cudaGetLastError());
+            });
+            return;
+        }
+        if constexpr (USE_FP4) {
+            dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
+                assert(!act_unsigned);
+                auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
+                    const packed_act_t *,
+                    const packed_wgt_t *,
+                    const packed_amscale_t *,
+                    const packed_wmscale_t *,
+                    float,
+                    int, int, int,
+                    typename Epilogue::Arguments,
+                    bool,
+                    bool>;
+                if (shmem >= 24 * 1024) {
+                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+                }
+                assert(ascales.dtype() == Tensor::FP8_E4M3);
+                assert(wscales.dtype() == Tensor::FP8_E4M3);
+                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
+                    act.data_ptr<packed_act_t>(),
+                    wgt.data_ptr<packed_wgt_t>(),
+                    ascales.data_ptr<packed_amscale_t>(),
+                    wscales.data_ptr<packed_wmscale_t>(),
+                    alpha,
+                    M, N, K,
+                    args,
+                    swapBlockMN,
+                    false
+                );
+                checkCUDA(cudaGetLastError());
+            });
+            return;
+        }
+        // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
+        //     throw std::runtime_error("FP4 kernel is not available");
+        // }
     });
     };
     auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
-        if (!bias.valid()) {
-            return launch.template operator()<NextEpilogue>(nextArgs);
-        }
-        assert(bias.numel() == N);
+        assert(!bias.valid() || bias.numel() == N);
+        assert(!wcscales.valid() || wcscales.numel() == N);
+        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
+        dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
+            using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
         // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
         // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
-        using Epilogue = typename GEMM::EpilogueCombination<typename GEMM::EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
+        using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
         return launch.template operator()<Epilogue>({
-            typename GEMM::EpilogueBias::Arguments{
-                .bias = bias.data_ptr<packed_wscale_t>(),
+            typename EpilogueBias::Arguments{
+                .bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
+                .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
             },
             nextArgs,
             {}
         });
+        });
+        });
     };
     // auto launch_bias = launch;
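The whole change hinges on dispatchBool turning a runtime flag (fp4, act_unsigned, alpha != 1.0f, bias.valid(), ...) into a compile-time template argument, so each branch instantiates a different kernel or epilogue. A minimal sketch of such a helper, assuming the repository's version behaves like this (its actual signature and constraints may differ):

```cpp
#include <utility>

// Forward a runtime bool to a C++20 lambda that takes a bool template
// parameter, e.g. dispatchBool(fp4, [&]<bool USE_FP4>() { ... });
template<typename F>
void dispatchBool(bool value, F &&func) {
    if (value) {
        std::forward<F>(func).template operator()<true>();
    } else {
        std::forward<F>(func).template operator()<false>();
    }
}
```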
@@ -206,29 +262,32 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         static constexpr float SHIFT_GELU = 0.171875f;
-        constexpr bool USE_UNSIGNED = true;
-        using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
-        auto argsQuantize = typename EpilogueQuantize::Arguments{
-            .qout = qout.data_ptr<packed_act_t>(),
-            .oscales = oscales.data_ptr<packed_ascale_t>(),
-            .shift_value = SHIFT_GELU,
-            .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
-        };
-        // TODO: check if gelu is needed
-        if (out.valid()) {
-            launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
-                typename GEMM::EpilogueDefault::Arguments{
-                    .out = out.data_ptr<half_t>(),
-                    .actualM = actualM,
-                    .actualN = actualN,
-                },
-                argsQuantize
-            }, {});
-        } else {
-            launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
-        }
+        dispatchBool(fp4, [&]<bool USE_FP4>() {
+            constexpr bool USE_UNSIGNED = !USE_FP4;
+            using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
+            auto argsQuantize = typename EpilogueQuantize::Arguments{
+                .qout = qout.data_ptr<packed_act_t>(),
+                .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
+                .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
+                .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
+            };
+            // TODO: check if gelu is needed
+            if (out.valid()) {
+                launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
+                    typename GEMM::EpilogueDefault::Arguments{
+                        .out = out.data_ptr<half_t>(),
+                        .actualM = actualM,
+                        .actualN = actualN,
+                    },
+                    argsQuantize
+                }, {});
+            } else {
+                launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
+            }
+        });
     } else if (out_linearattn.valid()) {
         assert(out_vk.valid());
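Aside (an interpretation, not stated in the diff): SHIFT_GELU = 0.171875 sits just above |min GELU(x)| ≈ 0.170, which would let the INT4 path quantize shifted GELU outputs as unsigned values, while the signed FP4 path sets shift_value to 0. A standalone numerical check of that minimum:

```cpp
#include <cmath>
#include <cstdio>

// Locate the minimum of exact GELU(x) = x * Phi(x) by brute force.
int main() {
    double min_val = 0.0, min_x = 0.0;
    for (double x = -5.0; x <= 0.0; x += 1e-4) {
        double gelu = x * 0.5 * (1.0 + std::erf(x / std::sqrt(2.0)));
        if (gelu < min_val) { min_val = gelu; min_x = x; }
    }
    std::printf("min GELU ~= %.4f at x ~= %.3f; 0.171875 covers it\n", min_val, min_x);
    return 0;
}
```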
@@ -326,7 +385,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
 }
 template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
+void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
     const int actualM = input.numel() / input.shape[-1];
     const int actualN = input.shape[-1];
@@ -338,8 +397,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
     assert(output.shape[-1] == N / 2);
     // assert(oscales.dtype() == Tensor::FP16);
-    assert(isTypeMatch<half_t>(oscales.dtype()));
-    assert(oscales.numel() == M * N / GEMM::WARP_K);
+    if (fp4) {
+        assert(oscales.dtype() == Tensor::FP8_E4M3);
+        assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
+    } else {
+        assert(isTypeMatch<half_t>(oscales.dtype()));
+        assert(oscales.numel() == M * N / GEMM::WARP_K);
+    }
     const int rank = lora_down.shape[1];
@@ -354,30 +418,32 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
     dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
         dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
+            dispatchBool(fp4, [&]<bool USE_FP4>() {
             using Lora = typename GEMM::Lora<RANK>;
-            using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU>;
+            using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
             auto func = invoke_kernel<kernel, typename kernel::Arguments>;
             checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
             // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
             func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
                 typename kernel::Arguments{
                     .input = input.data_ptr<half_t>(),
                     .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                     .output = output.data_ptr<packed_act_t>(),
-                    .oscales = oscales.data_ptr<packed_ascale_t>(),
+                    .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
                     .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                     .lora_act = lora_act_out.data_ptr<float>(),
                     .M = M,
                     .N = N,
                     .actualM = actualM,
                     .actualN = actualN,
                 }
             );
             checkCUDA(cudaGetLastError());
+            });
         });
     });
 }
...
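For the oscales asserts above, the factor of four is what you would expect if the INT4 path keeps one half-precision scale per GEMM::WARP_K-element group while the NVFP4 path keeps one FP8 E4M3 scale per 16-element micro-block. A small illustration of that bookkeeping, assuming WARP_K == 64 (an assumption; the diff does not state the value) and example shapes:

```cpp
// Compile-time scale-count arithmetic for an example activation tile.
constexpr int WARP_K = 64;                 // assumed INT4 group size
constexpr int M = 128, N = 4096;           // example shapes

constexpr int int4_scales = M * N / WARP_K;      // one FP16 scale per group
constexpr int fp4_scales  = M * N / WARP_K * 4;  // matches the fp4 assert above

static_assert(fp4_scales == M * N / 16, "equivalent to 16-element micro-blocks");
```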
@@ -100,9 +100,9 @@ void gemm_w8a8(Tensor act, // [M, K]
         // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
         // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
-        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
+        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
         return launch.template operator()<Epilogue>({
-            GEMM::EpilogueBias::Arguments{
+            GEMM::EpilogueBias<true, false>::Arguments{
                 .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
             },
             nextArgs,
...
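The w8a8 path keeps its bias-only behaviour by instantiating the now-parameterized epilogue as EpilogueBias<true, false>. A hypothetical, much-simplified model of why the two flags are compile-time parameters (the real epilogue works on packed half-precision fragments rather than single floats, so this is a sketch of the idea only):

```cuda
// Sketch only: disabled features compile away via if constexpr, so one
// epilogue type covers bias-only, scale-only, both, or neither.
template<bool USE_BIAS, bool USE_SCALE>
struct EpilogueBiasSketch {
    struct Arguments {
        const float *bias;   // read only when USE_BIAS
        const float *scale;  // read only when USE_SCALE
    };
    __device__ static float apply(float acc, int col, const Arguments &args) {
        if constexpr (USE_SCALE) acc *= args.scale[col];
        if constexpr (USE_BIAS)  acc += args.bias[col];
        return acc;
    }
};
```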
@@ -27,11 +27,14 @@ void gemm_w4a4(
     Tensor out_linearattn,// linear [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales, // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales
 );
 void linearattn_vk_mul_q(Tensor q, Tensor vk);
-void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false);
+void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false);
 void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
 void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
...