OpenDAS / AutoAWQ · Commits

Commit e120c9b6
Authored Sep 11, 2023 by Casper Hansen
Parent: 1aa8aebd

Use CUDA stream
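For context, the pattern this commit applies across both files: a raw <<<grid, block>>> launch enqueues the kernel on CUDA's legacy default stream (under nvcc's default settings), ignoring whatever stream PyTorch has made current, for example inside a torch.cuda.stream context. Passing at::cuda::getCurrentCUDAStream() as the fourth launch-configuration argument keeps the kernel ordered with the surrounding ATen ops. A minimal sketch of the pattern (the toy kernel and wrapper below are illustrative, not code from this repository):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void scale_kernel(float* x, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= s;
}

void scale_inplace(torch::Tensor t, float s) {
  const int n = static_cast<int>(t.numel());
  const int blocks = (n + 255) / 256;
  // Before this commit's pattern: scale_kernel<<<blocks, 256>>>(...) would
  // run on the legacy default stream regardless of PyTorch's current stream.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<blocks, 256, 0, stream>>>(t.data_ptr<float>(), s, n);
}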
Showing 2 changed files with 8 additions and 4 deletions.
awq_cuda/quantization/gemm_cuda_gen.cu  (+4, -2)
awq_cuda/quantization/gemv_cuda.cu      (+4, -2)
awq_cuda/quantization/gemm_cuda_gen.cu
@@ -10,6 +10,7 @@
  */
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "gemm_cuda.h"
 #include "dequantize.cuh"
 #include <cuda_fp16.h>
@@ -439,6 +440,7 @@ torch::Tensor gemm_forward_cuda(
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
     int group_size = num_in_channels / _scaling_factors.size(0);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     if (num_out_channels % 64 != 0)
         throw std::invalid_argument("OC is not multiple of cta_N = 64");
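Note that at::cuda::getCurrentCUDAStream() returns an at::cuda::CUDAStream wrapper, which converts implicitly to the raw cudaStream_t handle, so the assignment added above is equivalent to the explicit form below (a sketch for clarity, not part of the diff):

at::cuda::CUDAStream wrapper = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = wrapper.stream();  // same raw handle as the implicit conversion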
@@ -456,7 +458,7 @@ torch::Tensor gemm_forward_cuda(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
         group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
@@ -467,7 +469,7 @@ torch::Tensor gemm_forward_cuda(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
         group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
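One note on the unchanged last line: _out_feats appears to stack one partial result per split_k_iters iteration along dim 0, and sum(0) reduces them to the final output. Because sum is a regular ATen op it already runs on the current stream, so after this change the kernel launches and the reduction stay ordered on the same stream. A sketch of the reduction semantics (shapes are assumptions inferred from the visible signature):

// partials: [split_k_iters, M, N] -> out: [M, N]
torch::Tensor reduce_split_k(const torch::Tensor& partials) {
  return partials.sum(0);  // enqueued on the current CUDA stream like any ATen op
}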
awq_cuda/quantization/gemv_cuda.cu
@@ -13,6 +13,7 @@
 #include <cuda_fp16.h>
 #include <stdio.h>
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "gemv_cuda.h"
 #define VECTORIZE_FACTOR 8
 #define Q_VECTORIZE_FACTOR 8
@@ -224,9 +225,10 @@ torch::Tensor gemv_forward_cuda(
     int blockDim_z = num_out_feats;
     dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
     dim3 num_threads(32, 4);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (group_size == 64)
     {
-      gemv_kernel_g64<<<num_blocks, num_threads>>>(
+      gemv_kernel_g64<<<num_blocks, num_threads, 0, stream>>>(
         // pointers
         in_feats, kernel, zeros, scaling_factors, out_feats,
         // constants
@@ -235,7 +237,7 @@ torch::Tensor gemv_forward_cuda(
     }
     else if (group_size == 128)
     {
-      gemv_kernel_g128<<<num_blocks, num_threads>>>(
+      gemv_kernel_g128<<<num_blocks, num_threads, 0, stream>>>(
         // pointers
         in_feats, kernel, zeros, scaling_factors, out_feats,
         // constants
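With both files launching on the current stream, callers can route the whole forward pass onto a side stream; previously only the surrounding ATen ops would have moved, while these kernels stayed on the legacy default stream. Since PyTorch's pool streams are created as non-blocking with respect to the legacy default stream, that mismatch could race rather than merely serialize. A hypothetical C++ driver sketch (run_on_side_stream and the elided tensor arguments are illustrative, not part of the repository):

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

void run_on_side_stream(/* ... tensors ... */) {
  at::cuda::CUDAStream side = at::cuda::getStreamFromPool();
  {
    c10::cuda::CUDAStreamGuard guard(side);  // getCurrentCUDAStream() now returns `side`
    // gemm_forward_cuda(...) / gemv_forward_cuda(...) launch on `side` after this commit
  }
  side.synchronize();  // wait for all work queued on the side stream
}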