Commit 29678cd2 (unverified)
Authored Oct 15, 2023 by Woosuk Kwon; committed via GitHub on Oct 15, 2023
Parent: d0740dff

Minor fix on AWQ kernel launch (#1356)

1 changed file with 4 additions and 3 deletions:
csrc/quantization/awq/gemm_kernels.cu (+4, -3)
csrc/quantization/awq/gemm_kernels.cu

@@ -534,6 +534,7 @@ torch::Tensor awq_gemm(
     if (num_out_channels % group_size != 0)
         throw std::invalid_argument("OC is not multiple of Group size");
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (num_out_channels % 128 == 0)
     {
         int j_factors1 = num_out_channels / 128 / 1;
@@ -541,7 +542,7 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
@@ -552,7 +553,7 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
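The fix is small but meaningful: both AWQ GEMM kernels were previously launched with the two-argument execution configuration <<<num_blocks, threads_per_block>>>, which implicitly targets the legacy default stream, while PyTorch schedules its own work on the current CUDA stream. Passing at::cuda::getCurrentCUDAStream() as the fourth launch argument (the third, 0, is dynamic shared memory in bytes) keeps the kernels ordered with the surrounding PyTorch ops. Below is a minimal sketch of the same pattern in a standalone extension op; the scale_kernel and scale names are hypothetical illustrations, not vLLM code:

// Minimal sketch (hypothetical names, not vLLM code) of launching a CUDA
// kernel on the stream PyTorch is currently recording work on, mirroring
// the pattern this commit adopts.
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

__global__ void scale_kernel(float* data, float factor, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] *= factor;
}

torch::Tensor scale(torch::Tensor x, float factor) {
    // Fetch the current stream; CUDAStream converts implicitly to
    // cudaStream_t, exactly as in the commit's added line.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    int n = x.numel();
    dim3 block(256);
    dim3 grid((n + block.x - 1) / block.x);
    // Full launch configuration: grid, block, 0 bytes of dynamic shared
    // memory, and the current stream. Omitting the last two arguments
    // would launch on the legacy default stream instead.
    scale_kernel<<<grid, block, 0, stream>>>(x.data_ptr<float>(), factor, n);
    return x;
}

Without the explicit stream argument, a caller running the op under a non-default stream could, in principle, observe the output tensor before the kernel has finished, since the default stream is not ordered with respect to that stream.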