Commit 038b8469 authored by fengzch

fix: compile gemm_awq.cu complete

parent 54241df6
@@ -9,7 +9,7 @@
 // #include "../../../nunchaku/csrc/utils.cuh"
 #include "../utils.cuh"
-#include <cuda_pipeline_primitives.h>
+// #include <cuda_pipeline_primitives.h>
 #define kInterleave 4
 #define OP_M 16
@@ -46,8 +46,8 @@
 dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
 dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
 auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
-cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
-kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>( \
+hipFuncSetAttribute(reinterpret_cast<const void*>(kernel_func), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
+hipLaunchKernelGGL(( kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0, \
 in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
 template<int N>
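For context, the launch-macro change above follows the standard CUDA-to-HIP mapping: a triple-chevron launch preceded by cudaFuncSetAttribute becomes hipFuncSetAttribute on a type-erased function pointer plus hipLaunchKernelGGL. Below is a minimal stand-alone sketch of that pattern; the kernel copy_kernel, the helper launch_copy, and their arguments are illustrative names, not part of this file.

#include <hip/hip_runtime.h>

// Illustrative kernel, not part of gemm_awq.cu: copies through dynamic shared memory.
__global__ void copy_kernel(const float *in, float *out, int n) {
    extern __shared__ float smem[];                     // dynamic shared memory
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { smem[threadIdx.x] = in[i]; out[i] = smem[threadIdx.x]; }
}

void launch_copy(const float *in, float *out, int n) {
    const int block = 256;
    const int smem_bytes = block * sizeof(float);
    dim3 grid((n + block - 1) / block), threads(block);
    // CUDA form: cudaFuncSetAttribute(copy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    //            copy_kernel<<<grid, threads, smem_bytes>>>(in, out, n);
    // HIP form, as in the hunk above:
    hipFuncSetAttribute(reinterpret_cast<const void *>(copy_kernel),
                        hipFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    hipLaunchKernelGGL(copy_kernel, grid, threads, smem_bytes, 0 /*default stream*/, in, out, n);
}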
@@ -472,11 +472,11 @@ __global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
 k_0_0_ld,
 0,
 k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
-if constexpr (STAGES > 1)
-__pipeline_commit();
+// if constexpr (STAGES > 1)
+// __pipeline_commit();
 }
-if constexpr (STAGES > 1)
-__pipeline_wait_prior(STAGES - 2);
+// if constexpr (STAGES > 1)
+// __pipeline_wait_prior(STAGES - 2);
 __syncthreads();
 share_to_reg_one_stage_A<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
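The calls commented out in this hunk (and in the matching hunks below) come from CUDA's <cuda_pipeline_primitives.h>, whose include is also disabled in the first hunk: __pipeline_commit() closes a batch of asynchronous global-to-shared copies, and __pipeline_wait_prior(N) blocks until at most N batches remain in flight, which is what lets the multi-stage (STAGES) buffering overlap loads with compute. With those calls removed, correctness presumably rests on the remaining __syncthreads() and synchronous copies, at the cost of that overlap. A minimal sketch of the CUDA pattern being disabled follows; the kernel staged_copy, its buffers, and its launch shape are illustrative and not this kernel's code.

#include <cuda_pipeline_primitives.h>

// Launch with 128 threads per block; sizes are illustrative only.
template<int STAGES>
__global__ void staged_copy(const float4 *gmem, int iters) {
    static_assert(STAGES >= 2, "pipelining needs at least two stages");
    __shared__ float4 smem[STAGES][128];
    for (int k = 0; k < iters; ++k) {
        int stage = k % STAGES;
        // Issue an asynchronous global->shared copy into this iteration's stage buffer.
        __pipeline_memcpy_async(&smem[stage][threadIdx.x],
                                &gmem[k * blockDim.x + threadIdx.x],
                                sizeof(float4));
        __pipeline_commit();               // close this batch of async copies
        __pipeline_wait_prior(STAGES - 2); // allow up to STAGES-2 newer batches to stay in flight
        __syncthreads();
        // ... consume the stage whose copy has completed ...
    }
}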
@@ -641,17 +641,17 @@ __global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
 k_0_0_ld,
 iter_k,
 k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
-if constexpr (STAGES > 1) {
-__pipeline_commit();
-__pipeline_wait_prior(STAGES - 2);
-}
+// if constexpr (STAGES > 1) {
+// __pipeline_commit();
+// __pipeline_wait_prior(STAGES - 2);
+// }
 compute_stage = (k_0_0 + 1) % STAGES;
 __syncthreads();
 }
 }
 }
-__pipeline_commit();
-__pipeline_wait_prior(0);
+// __pipeline_commit();
+// __pipeline_wait_prior(0);
 __syncthreads();
 if constexpr (SLICES > 1) {
 #pragma unroll
@@ -1010,11 +1010,11 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
 k_0_0_ld,
 0,
 k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
-if constexpr (STAGES > 1)
-__pipeline_commit();
+// if constexpr (STAGES > 1)
+// __pipeline_commit();
 }
-if constexpr (STAGES > 1)
-__pipeline_wait_prior(STAGES - 2);
+// if constexpr (STAGES > 1)
+// __pipeline_wait_prior(STAGES - 2);
 __syncthreads();
 share_to_reg_one_stage_A_T2<f16_t, CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(
@@ -1165,10 +1165,10 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
 k_0_0_ld,
 iter_k,
 k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
-if constexpr (STAGES > 1) {
-__pipeline_commit();
-__pipeline_wait_prior(STAGES - 2);
-}
+// if constexpr (STAGES > 1) {
+// __pipeline_commit();
+// __pipeline_wait_prior(STAGES - 2);
+// }
 compute_stage = (k_0_0 + 1) % STAGES;
 __syncthreads();
 }
@@ -1277,7 +1277,7 @@ Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, T
 dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
 dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
 auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
-cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
+cudaFuncSetAttribute(reinterpret_cast<const void*>(kernel_func), cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
 kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
 in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
 }
@@ -1357,7 +1357,7 @@ Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, T
 dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
 dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
 auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
-cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
+cudaFuncSetAttribute(reinterpret_cast<const void*>(kernel_func), cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
 kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
 in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
 }
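In the last two hunks the call keeps its CUDA spelling but the typed kernel pointer is wrapped in reinterpret_cast<const void*>. That matches the commit message: when cudaFuncSetAttribute is mapped to hipFuncSetAttribute (for example through HIP's compatibility headers), the target presumably takes a const void*, and a function pointer does not implicitly convert to an object pointer in standard C++, so the explicit cast is what lets the file compile. One way to keep a single source tree building on both toolchains is a small guard like the sketch below; SET_MAX_DYN_SMEM is a made-up name for illustration, not something this commit introduces.

#if defined(__HIP_PLATFORM_AMD__)
  #include <hip/hip_runtime.h>
  // The HIP entry point takes a type-erased const void*, hence the explicit cast.
  #define SET_MAX_DYN_SMEM(fn, bytes) \
      hipFuncSetAttribute(reinterpret_cast<const void *>(fn), \
                          hipFuncAttributeMaxDynamicSharedMemorySize, (bytes))
#else
  #include <cuda_runtime.h>
  // CUDA provides a templated overload that accepts the typed kernel pointer directly.
  #define SET_MAX_DYN_SMEM(fn, bytes) \
      cudaFuncSetAttribute((fn), cudaFuncAttributeMaxDynamicSharedMemorySize, (bytes))
#endif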