Commit 2cb9a2c7 authored by fengzch

fix: make gemm_w4a4_launch_bf16_fp4.cu compile

parent 336949c2
@@ -88,7 +88,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
     if constexpr (!USE_FP4) {
         dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
-            auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
+            auto func = invoke_kernel<typename GEMM::template gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
                 const packed_act_t *,
                 const packed_wgt_t *,
                 const packed_ascale_t *,
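In this hunk and the next, the fix adds the `template` disambiguator. Inside gemm_w4a4, GEMM is a dependent type, so a conforming compiler parses GEMM::gemm_w4a4_kernel<...> with '<' as a less-than operator unless the member is explicitly marked as a template. A minimal sketch of the rule, using hypothetical stand-in types rather than the real kernel declarations from this file:

    // Hypothetical stand-ins for the real GEMM config and kernel templates.
    struct Gemm {
        template <bool ACT_UNSIGNED>
        struct kernel {};
    };

    template <typename GEMM>
    void launch() {
        // Ill-formed on strict compilers: 'kernel' is a dependent name,
        // so '<' is parsed as a comparison, not a template argument list.
        //   typename GEMM::kernel<true> k;

        // OK: 'template' tells the parser the dependent member is a template.
        typename GEMM::template kernel<true> k{};
        (void)k;
    }

    int main() { launch<Gemm>(); }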
@@ -126,7 +126,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
         dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
             assert(!act_unsigned);
-            auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
+            auto func = invoke_kernel<typename GEMM::template gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
                 const packed_act_t *,
                 const packed_wgt_t *,
                 const packed_amscale_t *,
@@ -140,7 +140,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
             bool>;
         if (shmem >= 24 * 1024) {
-            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+            checkCUDA(cudaFuncSetAttribute(reinterpret_cast<const void*>(func), cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         }
         assert(ascales.dtype() == Tensor::FP8_E4M3);
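This hunk and the one below change the first argument of cudaFuncSetAttribute from the raw function pointer to reinterpret_cast<const void*>(func). The cast selects the untyped const void* overload of the runtime API rather than the templated T* overload, presumably because the templated form failed to compile for these kernel pointers in this translation unit. A minimal sketch of the pattern, assuming a placeholder kernel and plain error handling in place of the repo's checkCUDA macro:

    #include <cuda_runtime.h>
    #include <cstdio>

    // Placeholder kernel; the real code sets this attribute on the
    // gemm_w4a4 kernel instantiations produced by invoke_kernel.
    __global__ void dummy_kernel(float *out) {
        extern __shared__ float smem[];
        smem[threadIdx.x] = 0.0f;
        out[threadIdx.x] = smem[threadIdx.x];
    }

    int main() {
        const int shmem = 48 * 1024; // request 48 KiB of dynamic shared memory
        // Casting to const void* picks the untyped overload of
        // cudaFuncSetAttribute, mirroring the fix in this commit.
        cudaError_t err = cudaFuncSetAttribute(
            reinterpret_cast<const void *>(dummy_kernel),
            cudaFuncAttributeMaxDynamicSharedMemorySize, shmem);
        if (err != cudaSuccess) {
            std::printf("cudaFuncSetAttribute failed: %s\n", cudaGetErrorString(err));
            return 1;
        }
        return 0;
    }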
@@ -495,7 +495,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
     auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+    checkCUDA(cudaFuncSetAttribute(reinterpret_cast<const void*>(func), cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
     // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
     //     input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));