"vscode:/vscode.git/clone" did not exist on "3c47bfdfb42d805ae18355cd59c609358fa1660c"
Unverified commit 29678cd2, authored by Woosuk Kwon, committed by GitHub

Minor fix on AWQ kernel launch (#1356)

parent d0740dff
@@ -534,6 +534,7 @@ torch::Tensor awq_gemm(
     if (num_out_channels % group_size != 0)
         throw std::invalid_argument("OC is not multiple of Group size");
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (num_out_channels % 128 == 0)
     {
         int j_factors1 = num_out_channels / 128 / 1;
@@ -541,18 +542,18 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
     {
         int j_factors1 = num_out_channels / 64 / 1;
         dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
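For context, the change passes PyTorch's current CUDA stream to the AWQ GEMM kernel launches instead of implicitly using the default stream, so the kernels are ordered with the other work PyTorch enqueues on that stream. Below is a minimal sketch of the same launch pattern in a standalone extension; the scale_kernel and its arguments are hypothetical placeholders, not part of this commit.

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

__global__ void scale_kernel(const float* in, float* out, float alpha, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = alpha * in[i];
}

torch::Tensor scale(torch::Tensor input, float alpha) {
    auto output = torch::empty_like(input);
    const int n = input.numel();
    // Query the stream PyTorch is currently using on this device, so the
    // kernel runs in order with preceding/following PyTorch ops.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    // Launch config: <<<grid, block, dynamic shared memory bytes, stream>>>.
    // Passing the stream explicitly is the fix mirrored by this commit.
    scale_kernel<<<blocks, threads, 0, stream>>>(
        input.data_ptr<float>(), output.data_ptr<float>(), alpha, n);
    return output;
}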