Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6a92ff93
Unverified
Commit
6a92ff93
authored
Mar 01, 2025
by
YajieWang
Committed by
GitHub
Feb 28, 2025
Browse files
[Misc][Kernel]: Add GPTQAllSpark Quantization (#12931)
parent
6a84164a
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
2005 additions
and
4 deletions
+2005
-4
CMakeLists.txt
CMakeLists.txt
+16
-0
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+45
-2
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
+1008
-0
csrc/quantization/gptq_allspark/allspark_repack.cu
csrc/quantization/gptq_allspark/allspark_repack.cu
+163
-0
csrc/quantization/gptq_allspark/allspark_utils.cuh
csrc/quantization/gptq_allspark/allspark_utils.cuh
+408
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+19
-0
tests/kernels/test_allspark_gemm.py
tests/kernels/test_allspark_gemm.py
+100
-0
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+0
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+77
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
...r/layers/quantization/kernels/mixed_precision/__init__.py
+3
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py
...r/layers/quantization/kernels/mixed_precision/allspark.py
+115
-0
vllm/model_executor/layers/quantization/utils/allspark_utils.py
...odel_executor/layers/quantization/utils/allspark_utils.py
+51
-0
No files found.
CMakeLists.txt
100755 → 100644
View file @
6a92ff93
...
...
@@ -317,6 +317,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures"
)
endif
()
# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection
(
ALLSPARK_ARCHS
"8.0;8.6;8.7;8.9"
"
${
CUDA_ARCHS
}
"
)
if
(
ALLSPARK_ARCHS
)
set
(
ALLSPARK_SRCS
"csrc/quantization/gptq_allspark/allspark_repack.cu"
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
ALLSPARK_SRCS
}
"
CUDA_ARCHS
"
${
ALLSPARK_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
"
${
ALLSPARK_SRCS
}
"
)
message
(
STATUS
"Building AllSpark kernels for archs:
${
ALLSPARK_ARCHS
}
"
)
else
()
message
(
STATUS
"Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures"
)
endif
()
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection
(
SCALED_MM_3X_ARCHS
"9.0a"
"
${
CUDA_ARCHS
}
"
)
...
...
benchmarks/kernels/benchmark_marlin.py
View file @
6a92ff93
...
...
@@ -10,6 +10,8 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
)
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
ALLSPARK_SUPPORTED_QUANT_TYPES
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
MARLIN_SUPPORTED_GROUP_SIZES
,
query_marlin_supported_quant_types
)
...
...
@@ -18,12 +20,12 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
gptq_quantize_weights
,
sort_weights
)
gptq_pack
,
gptq_quantize_weights
,
quantize_weights
,
sort_weights
)
from
vllm.scalar_type
import
ScalarType
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
...
...
@@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
GPTQ_MARLIN_24_MAX_PARALLEL
)
marlin_zp
=
torch
.
zeros_like
(
marlin_s
,
dtype
=
torch
.
int
)
# AllSpark W8A16 quant
as_supported_case
=
(
quant_type
in
ALLSPARK_SUPPORTED_QUANT_TYPES
and
group_size
==
-
1
and
not
act_order
and
is_k_full
)
if
as_supported_case
:
properties
=
torch
.
cuda
.
get_device_properties
(
b
.
device
.
index
)
sm_count
=
properties
.
multi_processor_count
sm_version
=
properties
.
major
*
10
+
properties
.
minor
supported_arch
=
(
sm_version
>=
80
and
sm_version
<
90
)
as_supported_case
=
as_supported_case
and
supported_arch
if
supported_arch
:
has_zp
=
False
w_ref
,
qw
,
s
,
zp
=
quantize_weights
(
b
,
quant_type
,
group_size
,
has_zp
)
qw
=
qw
.
to
(
torch
.
uint8
)
qw_reorder
,
s_reorder
,
zp_reorder
=
\
ops
.
allspark_repack_weight
(
qw
,
s
,
zp
,
has_zp
)
CUBLAS_M_THRESHOLD
=
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
globals
=
{
# Gen params
"quant_type"
:
quant_type
,
...
...
@@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
# GPTQ params
"q_w_gptq"
:
q_w_gptq
,
"repack_sort_indices"
:
repack_sort_indices
,
# AllSpark W8A16 params
"qw_reorder"
:
qw_reorder
if
as_supported_case
else
None
,
"s_reorder"
:
s_reorder
if
as_supported_case
else
None
,
"zp_reorder"
:
zp_reorder
if
as_supported_case
else
None
,
"sm_count"
:
sm_count
if
as_supported_case
else
None
,
"sm_version"
:
sm_version
if
as_supported_case
else
None
,
"CUBLAS_M_THRESHOLD"
:
CUBLAS_M_THRESHOLD
if
as_supported_case
else
None
,
# Kernels
"gptq_marlin_gemm"
:
ops
.
gptq_marlin_gemm
,
"gptq_marlin_24_gemm"
:
ops
.
gptq_marlin_24_gemm
,
"gptq_marlin_repack"
:
ops
.
gptq_marlin_repack
,
"allspark_w8a16_gemm"
:
ops
.
allspark_w8a16_gemm
,
}
min_run_time
=
1
...
...
@@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
description
=
"gptq_marlin_repack"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
if
as_supported_case
:
results
.
append
(
benchmark
.
Timer
(
stmt
=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
"allspark_w8a16_gemm_fp32"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
def
main
(
args
):
print
(
"Benchmarking models:"
)
...
...
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
0 → 100644
View file @
6a92ff93
#include "allspark_utils.cuh"
#include <torch/all.h>
#include "core/registration.h"
#include <cublas_v2.h>
at
::
Tensor
as_g_workspace
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
torch
::
Tensor
allspark_w8a16_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_qweight
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
b_qzeros
,
int64_t
n
,
int64_t
group_size
,
int64_t
sm_count
,
int64_t
sm_version
,
int64_t
CUBLAS_M_THRESHOLD
,
bool
has_zp
,
bool
n32k16_reorder
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"allspark_w8a16_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
namespace
allspark
{
/*
* GemmTile manage data movement from Global Memory to Shared Memory
* requiring N % 8 == 0, K % 16 == 0 by loading uint
* BN is obtained by padding the original N to a multiple of 32
* weight B is rearranged as N32K16 order,
* i.e. a initial data block of size 32(n)x16(k) is reordered as n8k4n4k4,
* in order to put data loaded by the same thread of 32x16 data block together
* continuously (see
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type)
*/
template
<
typename
FType
,
typename
QType
,
int
Mtile
,
int
Ntile
,
int
NStage
,
int
BLOCK
>
struct
GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK
{
// element num loaded by a LDG inst.
static
constexpr
int
LDG_ELEMENT_CNT_A
=
8
;
static
constexpr
int
LDG_ELEMENT_CNT_B
=
16
;
static
constexpr
int
WARP_SIZE
=
32
;
static
constexpr
int
M_SIZE_ONE_LOAD
=
(
BLOCK
*
LDG_ELEMENT_CNT_A
)
/
32
;
static
constexpr
int
N_SIZE_ONE_LOAD
=
(
BLOCK
*
LDG_ELEMENT_CNT_B
)
/
32
;
__device__
GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK
(
const
SM8x_GEMM_W8A16_Splitk_Params
<
FType
,
QType
>&
k_params
,
const
uint32_t
&
A_smem_addr
,
const
uint32_t
&
BQ_smem_addr
,
const
uint32_t
&
A_stage_stride
,
const
uint32_t
&
BQ_stage_stride
)
:
params
(
k_params
),
A_smem_base_addr
(
A_smem_addr
),
BQ_smem_base_addr
(
BQ_smem_addr
),
A_smem_stage_stride
(
A_stage_stride
),
BQ_smem_stage_stride
(
BQ_stage_stride
)
{
this_block_A_base_ptr
=
params
.
A_ptr
+
blockIdx
.
x
*
Mtile
*
params
.
K
+
blockIdx
.
z
*
params
.
SplitK
;
// here B is rearranged as N32K16 order, i.e. 4 continuous N-direction
// 8(N)x16(K) size data blocks are packed together
this_block_B_base_ptr
=
params
.
B_ptr
+
blockIdx
.
y
*
Ntile
*
params
.
K
+
blockIdx
.
z
*
params
.
SplitK
*
4
;
const
int
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
// For matrix A, a block load/store Mtile(row) x 32(col) elements in
// multiple iters, 8x4 warp load/store 8(row) x 32(col) elements per iter
const
int
Aldg_row_base_idx
=
threadIdx
.
x
/
4
;
Aldg_col_idx
=
(
threadIdx
.
x
%
4
)
*
LDG_ELEMENT_CNT_A
;
const
int
Aldg_base_offset
=
Aldg_row_base_idx
*
params
.
K
+
Aldg_col_idx
;
// For matrix B, a block load/store elements of (Ntile / 4) row x 128 col
// elements of N32K16 packing in multiple iters, 4x8 warp load/store 4(row)
// * 128(col) per iter
Bldg_col_idx
=
(
threadIdx
.
x
%
8
)
*
LDG_ELEMENT_CNT_B
;
const
int
Bldg_row_base_idx
=
threadIdx
.
x
/
8
;
const
int
Bldg_base_offset
=
Bldg_row_base_idx
*
params
.
K
*
4
+
Bldg_col_idx
;
this_block_A_base_ptr
+=
Aldg_base_offset
;
this_block_B_base_ptr
+=
Bldg_base_offset
;
const
int
sts_a_base_offset
=
(
threadIdx
.
x
/
4
)
*
32
+
((
lane_id
%
4
)
^
((
lane_id
/
4
)
%
4
)
^
((
lane_id
/
4
)
/
4
))
*
LDG_ELEMENT_CNT_A
;
const
int
sts_bq_base_offset
=
Bldg_row_base_idx
*
32
*
4
+
((
threadIdx
.
x
%
8
)
^
(((
threadIdx
.
x
/
8
)
%
2
)
*
4
))
*
LDG_ELEMENT_CNT_B
;
A_smem_base_addr
+=
sts_a_base_offset
*
sizeof
(
FType
);
BQ_smem_base_addr
+=
sts_bq_base_offset
*
sizeof
(
uint8_t
);
A_ldg_guard
=
0
;
B_ldg_guard
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Mtile
+
M_SIZE_ONE_LOAD
-
1
)
/
M_SIZE_ONE_LOAD
;
++
i
)
{
int
m_idx
=
blockIdx
.
x
*
Mtile
+
Aldg_row_base_idx
+
i
*
M_SIZE_ONE_LOAD
;
if
(
m_idx
<
params
.
M
)
{
A_ldg_guard
|=
(
1u
<<
i
);
}
}
const
int
N_padded
=
(
params
.
N
+
31
)
/
32
*
32
;
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Ntile
+
N_SIZE_ONE_LOAD
-
1
)
/
N_SIZE_ONE_LOAD
;
++
i
)
{
int
n_idx
=
blockIdx
.
y
*
Ntile
+
(
Bldg_row_base_idx
/
8
)
*
32
+
i
*
N_SIZE_ONE_LOAD
;
if
(
n_idx
<
N_padded
)
{
B_ldg_guard
|=
(
1u
<<
i
);
}
}
}
__device__
void
ldgsts_first_ktiles
(
const
int
&
first_k_tile
,
const
int
&
k_tiles
)
{
// load first k_tile
// load A
const
int
A_src_size
=
Aldg_col_idx
<
first_k_tile
?
16
:
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Mtile
+
M_SIZE_ONE_LOAD
-
1
)
/
M_SIZE_ONE_LOAD
;
++
i
)
{
cp_async
<
16
>
(
A_smem_base_addr
+
(
i
*
M_SIZE_ONE_LOAD
*
32
)
*
sizeof
(
FType
),
this_block_A_base_ptr
+
i
*
M_SIZE_ONE_LOAD
*
params
.
K
,
A_src_size
,
(
A_ldg_guard
&
(
1u
<<
i
))
!=
0
);
}
// load B
const
int
B_src_size
=
(
Bldg_col_idx
/
4
)
<
first_k_tile
?
16
:
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Ntile
+
N_SIZE_ONE_LOAD
-
1
)
/
N_SIZE_ONE_LOAD
;
++
i
)
{
cp_async
<
16
>
(
BQ_smem_base_addr
+
(
i
*
N_SIZE_ONE_LOAD
*
32
)
*
sizeof
(
uint8_t
),
this_block_B_base_ptr
+
i
*
N_SIZE_ONE_LOAD
*
params
.
K
,
B_src_size
,
(
B_ldg_guard
&
(
1u
<<
i
))
!=
0
);
}
cp_async_commit_group
();
this_block_A_base_ptr
+=
first_k_tile
;
this_block_B_base_ptr
+=
(
first_k_tile
*
4
);
// load second to (N-stage - 1) k_tiles
for
(
int
stage_idx
=
1
;
stage_idx
<
NStage
-
1
;
++
stage_idx
)
{
if
(
stage_idx
<
k_tiles
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Mtile
+
M_SIZE_ONE_LOAD
-
1
)
/
M_SIZE_ONE_LOAD
;
++
i
)
{
cp_async
<
16
>
(
A_smem_base_addr
+
stage_idx
*
A_smem_stage_stride
+
(
i
*
M_SIZE_ONE_LOAD
*
32
)
*
sizeof
(
FType
),
this_block_A_base_ptr
+
i
*
M_SIZE_ONE_LOAD
*
params
.
K
,
16
,
(
A_ldg_guard
&
(
1u
<<
i
))
!=
0
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Ntile
+
N_SIZE_ONE_LOAD
-
1
)
/
N_SIZE_ONE_LOAD
;
++
i
)
{
cp_async
<
16
>
(
BQ_smem_base_addr
+
stage_idx
*
BQ_smem_stage_stride
+
(
i
*
N_SIZE_ONE_LOAD
*
32
)
*
sizeof
(
uint8_t
),
this_block_B_base_ptr
+
i
*
N_SIZE_ONE_LOAD
*
params
.
K
,
16
,
(
B_ldg_guard
&
(
1u
<<
i
))
!=
0
);
}
this_block_A_base_ptr
+=
32
;
this_block_B_base_ptr
+=
(
32
*
4
);
}
cp_async_commit_group
();
}
}
__device__
void
ldgsts
(
const
int
&
sts_stage_idx
)
{
const
int
a_stage_offset
=
sts_stage_idx
*
A_smem_stage_stride
;
const
int
bq_stage_offset
=
sts_stage_idx
*
BQ_smem_stage_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Mtile
+
M_SIZE_ONE_LOAD
-
1
)
/
M_SIZE_ONE_LOAD
;
++
i
)
{
cp_async
<
16
>
(
A_smem_base_addr
+
a_stage_offset
+
(
i
*
M_SIZE_ONE_LOAD
*
32
)
*
sizeof
(
FType
),
this_block_A_base_ptr
+
i
*
M_SIZE_ONE_LOAD
*
params
.
K
,
16
,
(
A_ldg_guard
&
(
1u
<<
i
))
!=
0
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Ntile
+
N_SIZE_ONE_LOAD
-
1
)
/
N_SIZE_ONE_LOAD
;
++
i
)
{
cp_async
<
16
>
(
BQ_smem_base_addr
+
bq_stage_offset
+
(
i
*
N_SIZE_ONE_LOAD
*
32
)
*
sizeof
(
uint8_t
),
this_block_B_base_ptr
+
i
*
N_SIZE_ONE_LOAD
*
params
.
K
,
16
,
(
B_ldg_guard
&
(
1u
<<
i
))
!=
0
);
}
cp_async_commit_group
();
this_block_A_base_ptr
+=
32
;
this_block_B_base_ptr
+=
(
32
*
4
);
}
const
FType
*
this_block_A_base_ptr
=
nullptr
;
const
QType
*
this_block_B_base_ptr
=
nullptr
;
int
Aldg_col_idx
;
int
Bldg_col_idx
;
uint32_t
A_ldg_guard
;
uint32_t
B_ldg_guard
;
uint32_t
A_smem_base_addr
,
BQ_smem_base_addr
;
const
uint32_t
A_smem_stage_stride
,
BQ_smem_stage_stride
;
const
SM8x_GEMM_W8A16_Splitk_Params
<
FType
,
QType
>&
params
;
};
/*
* requiring N % 8 == 0
*/
template
<
typename
FType
,
typename
QType
,
int
Mtile
,
int
Ntile
,
int
BLOCK
,
bool
EnableFuse
,
bool
has_zp
>
struct
ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK
{
static
constexpr
int
WARP_SIZE
=
32
;
static
constexpr
int
WARP_CNT
=
BLOCK
/
WARP_SIZE
;
static
constexpr
int
WARP_NTILE
=
Ntile
/
WARP_CNT
;
static
constexpr
int
WARP_NITER
=
WARP_NTILE
/
8
;
// hmma16816
static_assert
(
WARP_NTILE
==
32
or
WARP_NTILE
==
64
,
"now only support WARP_NTILE = 32 or 64!"
);
__device__
ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK
(
const
SM8x_GEMM_W8A16_Splitk_Params
<
FType
,
QType
>&
k_params
,
const
uint32_t
&
A_smem_addr
,
const
uint32_t
&
BQ_smem_addr
,
const
uint32_t
&
A_stage_stride
,
const
uint32_t
&
BQ_stage_stride
)
:
params
(
k_params
),
A_smem_base_addr
(
A_smem_addr
),
BQ_smem_base_addr
(
BQ_smem_addr
),
A_smem_stage_stride
(
A_stage_stride
),
BQ_smem_stage_stride
(
BQ_stage_stride
)
{
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
load_a_base_offset
[
0
]
=
(
lane_id
%
16
)
*
32
+
((
lane_id
/
16
)
^
(
lane_id
%
4
)
^
((
lane_id
/
4
)
%
2
))
*
8
;
load_a_base_offset
[
1
]
=
(
lane_id
%
16
)
*
32
+
((
lane_id
/
16
+
2
)
^
(
lane_id
%
4
)
^
((
lane_id
/
4
)
%
2
))
*
8
;
load_b_base_offset
[
0
]
=
(
lane_id
/
4
+
warp_id
*
(
WARP_NTILE
/
4
))
*
32
*
4
+
(
lane_id
%
4
)
*
16
+
((
lane_id
/
4
)
%
2
)
*
16
*
4
;
load_b_base_offset
[
1
]
=
(
lane_id
/
4
+
warp_id
*
(
WARP_NTILE
/
4
))
*
32
*
4
+
(
lane_id
%
4
)
*
16
+
(((
lane_id
/
4
)
%
2
)
^
1
)
*
16
*
4
;
sts_c_base_offset
=
warp_id
*
Mtile
*
WARP_NTILE
+
(
lane_id
/
4
)
*
WARP_NTILE
+
(
lane_id
%
4
)
*
2
;
if
(
EnableFuse
)
{
this_block_C_base_ptr
=
params
.
C_ptr
+
blockIdx
.
x
*
Mtile
*
params
.
N
+
blockIdx
.
y
*
Ntile
;
}
else
{
this_block_C_base_ptr
=
params
.
C_split_ptr
+
blockIdx
.
z
*
params
.
M
*
params
.
N
+
blockIdx
.
x
*
Mtile
*
params
.
N
+
blockIdx
.
y
*
Ntile
;
}
int
store_thds_in_row
=
WARP_NTILE
/
8
;
store_c_row_base_idx
=
lane_id
/
store_thds_in_row
;
store_c_col_idx
=
warp_id
*
WARP_NTILE
+
(
lane_id
%
store_thds_in_row
)
*
8
;
store_c_base_offset
=
store_c_row_base_idx
*
params
.
N
+
store_c_col_idx
;
#pragma unroll
for
(
int
i
=
0
;
i
<
Mtile
/
16
;
++
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
WARP_NITER
;
++
j
)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
C_frag
[
i
][
j
][
k
]
=
0.
f
;
}
}
}
params_n_idx
=
blockIdx
.
y
*
Ntile
+
warp_id
*
WARP_NTILE
+
(
lane_id
/
4
)
*
4
;
}
__device__
void
lds
(
const
int
&
smem_stage_idx
,
const
int
&
reg_buf_idx
,
const
int
&
k_phase_idx
)
{
uint32_t
A_smem_addr
=
A_smem_base_addr
+
A_smem_stage_stride
*
smem_stage_idx
;
uint32_t
B_smem_addr
=
BQ_smem_base_addr
+
BQ_smem_stage_stride
*
smem_stage_idx
;
#pragma unroll
for
(
int
i
=
0
;
i
<
Mtile
/
16
;
++
i
)
{
ldsm_4
(
A_frag
[
reg_buf_idx
][
i
][
0
],
A_frag
[
reg_buf_idx
][
i
][
1
],
A_frag
[
reg_buf_idx
][
i
][
2
],
A_frag
[
reg_buf_idx
][
i
][
3
],
A_smem_addr
+
(
load_a_base_offset
[
k_phase_idx
]
+
i
*
16
*
32
)
*
sizeof
(
FType
));
}
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_NTILE
/
32
;
++
i
)
{
lds128
(
BQ_frag
[
reg_buf_idx
][
4
*
i
+
0
],
BQ_frag
[
reg_buf_idx
][
4
*
i
+
1
],
BQ_frag
[
reg_buf_idx
][
4
*
i
+
2
],
BQ_frag
[
reg_buf_idx
][
4
*
i
+
3
],
B_smem_addr
+
(
load_b_base_offset
[
k_phase_idx
]
+
i
*
32
*
32
)
*
sizeof
(
uint8_t
));
}
// dequant B
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_NITER
/
2
;
++
i
)
{
cvt_8bx4_to_16bx4_bias128
(
BQ_frag
[
reg_buf_idx
][
2
*
i
],
BF_frag
[
reg_buf_idx
][
2
*
i
]);
if
(
has_zp
)
{
BF_frag
[
reg_buf_idx
][
2
*
i
][
0
]
=
__hsub2
(
BF_frag
[
reg_buf_idx
][
2
*
i
][
0
],
num2num2
(
B_zero
[
i
].
x
));
BF_frag
[
reg_buf_idx
][
2
*
i
][
1
]
=
__hsub2
(
BF_frag
[
reg_buf_idx
][
2
*
i
][
1
],
num2num2
(
B_zero
[
i
].
x
));
}
BF_frag
[
reg_buf_idx
][
2
*
i
][
0
]
=
__hmul2
(
BF_frag
[
reg_buf_idx
][
2
*
i
][
0
],
num2num2
(
B_scale
[
i
].
x
));
BF_frag
[
reg_buf_idx
][
2
*
i
][
1
]
=
__hmul2
(
BF_frag
[
reg_buf_idx
][
2
*
i
][
1
],
num2num2
(
B_scale
[
i
].
x
));
cvt_8bx4_to_16bx4_bias128
(
BQ_frag
[
reg_buf_idx
][
2
*
i
+
1
],
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
]);
if
(
has_zp
)
{
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
][
0
]
=
__hsub2
(
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
][
0
],
num2num2
(
B_zero
[
i
].
y
));
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
][
1
]
=
__hsub2
(
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
][
1
],
num2num2
(
B_zero
[
i
].
y
));
}
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
][
0
]
=
__hmul2
(
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
][
0
],
num2num2
(
B_scale
[
i
].
y
));
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
][
1
]
=
__hmul2
(
BF_frag
[
reg_buf_idx
][
2
*
i
+
1
][
1
],
num2num2
(
B_scale
[
i
].
y
));
}
}
__device__
void
ldg_params
()
{
const
int
N_padded
=
(
params
.
N
+
31
)
/
32
*
32
;
// load B scale and zero_point
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_NTILE
/
32
;
++
i
)
{
ldg64_ca
(
B_scale
[
2
*
i
+
0
],
B_scale
[
2
*
i
+
1
],
params
.
B_scale_ptr
+
params_n_idx
+
i
*
32
,
(
params_n_idx
+
i
*
32
)
<
N_padded
);
if
(
has_zp
)
{
ldg64_ca
(
B_zero
[
2
*
i
+
0
],
B_zero
[
2
*
i
+
1
],
params
.
B_zero_ptr
+
params_n_idx
+
i
*
32
,
(
params_n_idx
+
i
*
32
)
<
N_padded
);
}
}
}
__device__
void
mma
(
const
int
&
reg_buf_idx
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
Mtile
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_NITER
;
++
n_idx
)
{
hmma16816_f32
<
FType
>
(
C_frag
[
m_idx
][
n_idx
],
A_frag
[
reg_buf_idx
][
m_idx
],
reinterpret_cast
<
uint32_t
(
&
)[
2
]
>
(
BF_frag
[
reg_buf_idx
][
n_idx
]));
}
}
}
__device__
void
fused_splitk_reduce
()
{
// need splitk-reduce if enable splitk
if
(
gridDim
.
z
>
1
)
{
int
blk_red_idx
=
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
;
// Wait for all previous blocks in the splitk direction to accumulate the
// results into C_tmp
if
(
threadIdx
.
x
==
0
)
{
uint32_t
*
red_count_ptr
=
params
.
red_count_ptr
+
blk_red_idx
;
uint32_t
count
;
do
{
// make sure the ld.cg inside the do-wile loop
__threadfence_block
();
asm
volatile
(
"ld.global.cg.b32 %0, [%1];"
:
"=r"
(
count
)
:
"l"
(
red_count_ptr
));
}
while
(
count
!=
blockIdx
.
z
);
}
__syncthreads
();
int
C_tmp_base_offset
=
blk_red_idx
*
Mtile
*
Ntile
+
threadIdx
.
x
*
4
;
if
(
blockIdx
.
z
!=
0
)
{
// expecting that temporary register here reuses the previous A&B frag
// register
float
temp_frag
[
Mtile
/
16
][
WARP_NITER
][
4
];
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
Mtile
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_NITER
;
++
n_idx
)
{
int
offset
=
C_tmp_base_offset
+
(
m_idx
*
WARP_NITER
+
n_idx
)
*
BLOCK
*
4
;
*
reinterpret_cast
<
int4
*>
(
temp_frag
[
m_idx
][
n_idx
])
=
*
reinterpret_cast
<
int4
*>
(
params
.
C_tmp_ptr
+
offset
);
}
}
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
Mtile
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_NITER
;
++
n_idx
)
{
#pragma unroll
for
(
int
idx
=
0
;
idx
<
4
;
++
idx
)
{
C_frag
[
m_idx
][
n_idx
][
idx
]
+=
temp_frag
[
m_idx
][
n_idx
][
idx
];
}
}
}
}
// first splitk - 1 blocks need to write partial results into C_tmp
if
(
blockIdx
.
z
!=
gridDim
.
z
-
1
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
Mtile
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_NITER
;
++
n_idx
)
{
int
offset
=
C_tmp_base_offset
+
(
m_idx
*
WARP_NITER
+
n_idx
)
*
BLOCK
*
4
;
asm
volatile
(
"{st.global.cg.v4.b32 [%0], {%1, %2, %3, %4};}
\n
"
:
:
"l"
(
params
.
C_tmp_ptr
+
offset
),
"f"
(
C_frag
[
m_idx
][
n_idx
][
0
]),
"f"
(
C_frag
[
m_idx
][
n_idx
][
1
]),
"f"
(
C_frag
[
m_idx
][
n_idx
][
2
]),
"f"
(
C_frag
[
m_idx
][
n_idx
][
3
]));
}
}
__threadfence
();
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
uint32_t
*
red_count_ptr
=
params
.
red_count_ptr
+
blk_red_idx
;
atomicInc
(
red_count_ptr
,
gridDim
.
z
);
}
}
}
}
__device__
void
stg
(
char
*
smem
)
{
if
(
EnableFuse
)
{
if
(
blockIdx
.
z
!=
gridDim
.
z
-
1
)
return
;
}
uint32_t
*
C_sts_ptr
=
reinterpret_cast
<
uint32_t
*>
(
smem
+
sts_c_base_offset
*
sizeof
(
FType
));
// C_tile sts
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
Mtile
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_NITER
;
++
n_idx
)
{
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
2
;
++
k_idx
)
{
FType
low16
=
static_cast
<
FType
>
(
C_frag
[
m_idx
][
n_idx
][
k_idx
*
2
]);
FType
high16
=
static_cast
<
FType
>
(
C_frag
[
m_idx
][
n_idx
][
k_idx
*
2
+
1
]);
uint32_t
tmp
=
(
reinterpret_cast
<
uint32_t
&>
(
low16
)
&
0xffff
)
|
(
reinterpret_cast
<
uint32_t
&>
(
high16
)
<<
16
);
int
sts_offset
=
m_idx
*
16
*
(
WARP_NTILE
/
2
)
+
(((
lane_id
/
(
32
/
WARP_NITER
))
+
n_idx
)
%
WARP_NITER
)
*
(
8
/
2
)
+
k_idx
*
8
*
(
WARP_NTILE
/
2
);
C_sts_ptr
[
sts_offset
]
=
tmp
;
}
}
}
__syncthreads
();
FType
*
C_base_ptr
=
this_block_C_base_ptr
+
store_c_base_offset
;
// C_tile lds and stg
int
m_base_idx
=
store_c_row_base_idx
+
blockIdx
.
x
*
Mtile
;
bool
n_guard
=
(
store_c_col_idx
+
blockIdx
.
y
*
Ntile
)
<
params
.
N
;
if
(
WARP_NTILE
==
32
)
{
int
lds_c_base_offset
=
warp_id
*
Mtile
*
WARP_NTILE
+
(
lane_id
/
4
)
*
WARP_NTILE
+
((
lane_id
%
4
+
lane_id
/
8
)
%
4
)
*
8
;
uint4
*
C_lds_ptr
=
reinterpret_cast
<
uint4
*>
(
smem
+
lds_c_base_offset
*
sizeof
(
FType
));
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Mtile
/
16
)
*
(
WARP_NITER
/
2
);
++
i
)
{
uint4
stg_reg
=
C_lds_ptr
[
i
*
8
*
4
];
stg128
(
stg_reg
.
x
,
stg_reg
.
y
,
stg_reg
.
z
,
stg_reg
.
w
,
C_base_ptr
+
i
*
8
*
params
.
N
,
(
m_base_idx
+
i
*
8
)
<
params
.
M
&&
n_guard
);
}
}
else
if
(
WARP_NTILE
==
64
)
{
int
lds_c_base_offset
=
warp_id
*
Mtile
*
WARP_NTILE
+
(
lane_id
/
8
)
*
WARP_NTILE
;
#pragma unroll
for
(
int
i
=
0
;
i
<
(
Mtile
/
16
)
*
(
WARP_NITER
/
2
);
++
i
)
{
int
lds_c_offset
=
lds_c_base_offset
+
i
*
4
*
WARP_NTILE
+
((
lane_id
%
8
+
lane_id
/
8
+
(
i
%
2
)
*
4
)
%
8
)
*
8
;
uint4
stg_reg
=
*
reinterpret_cast
<
uint4
*>
(
smem
+
lds_c_offset
*
sizeof
(
FType
));
stg128
(
stg_reg
.
x
,
stg_reg
.
y
,
stg_reg
.
z
,
stg_reg
.
w
,
C_base_ptr
+
i
*
4
*
params
.
N
,
(
m_base_idx
+
i
*
4
)
<
params
.
M
&&
n_guard
);
}
}
}
const
SM8x_GEMM_W8A16_Splitk_Params
<
FType
,
QType
>&
params
;
int
load_a_base_offset
[
2
];
int
load_b_base_offset
[
2
];
int
sts_c_base_offset
;
int
store_c_base_offset
;
int
store_c_row_base_idx
,
store_c_col_idx
;
FType
*
this_block_C_base_ptr
=
nullptr
;
int
params_n_idx
;
const
uint32_t
A_smem_base_addr
,
BQ_smem_base_addr
;
const
uint32_t
A_smem_stage_stride
,
BQ_smem_stage_stride
;
int
lane_id
;
int
warp_id
;
// first 2 denotes double buffer, second dim denotes M direction
uint32_t
A_frag
[
2
][
Mtile
/
16
][
4
];
typename
HalfType
<
FType
>::
T2
B_scale
[
WARP_NITER
/
2
];
typename
HalfType
<
FType
>::
T2
B_zero
[
WARP_NITER
/
2
];
uint32_t
BQ_frag
[
2
][
WARP_NITER
];
// first 2 denotes double buffer, second dim denotes N direction, last 2
// denotes K direction
typename
HalfType
<
FType
>::
T2
BF_frag
[
2
][
WARP_NITER
][
2
];
// first dim denotes M direction, second dim denotes N direction
float
C_frag
[
Mtile
/
16
][
WARP_NITER
][
4
];
};
/*
* @brief W8A16 Perchannel Quantization GEMM,
* requires N % 8 == 0, K % 16 == 0
* accumulator precision: FP32
* @tparam FType: DataType for A, B_scale, B_zero, and C, supports half or
* nv_bfloat16
* @tparam QType: DataType for B, support uint8(bias128)
* @tparam Mtile: M-dimensional size of the gemm block tile, supports 16, 32,
* 48 or 64
* @tparam Ntile: N-dimensional size of the gemm block tile, supports 128 or
* 256
* @tparam NStage: Num of stages for async copy
* @tparam BLOCK: BLOCK size
* @tparam EnableFuse: If true, use fused splitk-reduce, otherwise use
* non-fused splitk-reduce
* @tparam has_zp: whether to use zero_point
*
* @fparam params struct consists of following parameters:
* @param A_ptr: Matrix A value ptr, A = (M, K)
* @param B_ptr: Matrix B value ptr, B = (N32_align, K) (N32K16 special
* format), N32_align = (N + 32 - 1) / 32 * 32
* @param B_scale_ptr: B_scale value ptr, B_scale = (N32_align,) (N32K16
* special format)
* @param B_zero_ptr: B_zero value ptr, B_zero = (N32_align,) (N32K16
* special format)
* @param C_ptr: Matrix C value ptr, C = (M, N)
* @param M: dimnesion m
* @param N: dimnesion n
* @param K: dimnesion k
* @param SplitK: split size along K-dimension
* @param C_split_ptr: Matrix C_split value ptr, used only in non-fused
* splitk-reduce
* @param C_tmp_ptr: Matrix C_tmp value ptr, used only in fused
* splitk-reduce
* @param red_count_ptr: 1-D red_count value ptr, used only in fused
* splitk-reduce
*/
template
<
typename
FType
,
typename
QType
,
int
Mtile
,
int
Ntile
,
int
NStage
,
int
BLOCK
,
bool
EnableFuse
,
bool
has_zp
>
__global__
void
__launch_bounds__
(
BLOCK
)
ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_hmma16816_multistage_AN_BTN32K16_CN_splitk_kernel
(
const
SM8x_GEMM_W8A16_Splitk_Params
<
FType
,
QType
>
params
)
{
// A smem size = 64 * 32 * 2B/elem * 4(stage) = 16KB
// B smem size = 128 * 32 * 1B/elem * 4(stage) = 16KB
constexpr
int
smem_size_one_stage
=
Mtile
*
32
*
2
+
Ntile
*
32
;
__shared__
char
smem
[
NStage
*
smem_size_one_stage
];
char
*
A_smem
=
smem
;
char
*
BQ_smem
=
smem
+
Mtile
*
32
*
2
*
NStage
;
uint32_t
A_smem_addr
=
smem_u32addr
(
A_smem
);
uint32_t
BQ_smem_addr
=
smem_u32addr
(
BQ_smem
);
uint32_t
A_smem_stage_stride
=
Mtile
*
32
*
2
;
uint32_t
BQ_smem_stage_stride
=
Ntile
*
32
;
// initialize the data move process from GM to SMEM for this block
GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK
<
FType
,
QType
,
Mtile
,
Ntile
,
NStage
,
BLOCK
>
gmem_tile
(
params
,
A_smem_addr
,
BQ_smem_addr
,
A_smem_stage_stride
,
BQ_smem_stage_stride
);
int
sts_stage_idx
=
0
;
int
lds_stage_idx
=
0
;
int
tb_k_slice
=
blockIdx
.
z
*
params
.
SplitK
+
params
.
SplitK
<=
params
.
K
?
params
.
SplitK
:
params
.
K
-
blockIdx
.
z
*
params
.
SplitK
;
int
k_tiles
=
(
tb_k_slice
+
31
)
/
32
;
int
first_k_tile
=
tb_k_slice
-
(
k_tiles
-
1
)
*
32
;
// load first three tiles to shared memory
gmem_tile
.
ldgsts_first_ktiles
(
first_k_tile
,
k_tiles
);
sts_stage_idx
+=
(
NStage
-
2
);
ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK
<
FType
,
QType
,
Mtile
,
Ntile
,
BLOCK
,
EnableFuse
,
has_zp
>
compute_tile
(
params
,
A_smem_addr
,
BQ_smem_addr
,
A_smem_stage_stride
,
BQ_smem_stage_stride
);
compute_tile
.
ldg_params
();
cp_asyc_wait_group
<
NStage
-
2
>
();
__syncthreads
();
compute_tile
.
lds
(
lds_stage_idx
,
0
,
0
);
int
reg_buf_idx
=
1
;
// main loop
for
(;
k_tiles
>
NStage
-
1
;
--
k_tiles
)
{
// load next A&B tile
sts_stage_idx
=
sts_stage_idx
<
NStage
-
1
?
sts_stage_idx
+
1
:
0
;
gmem_tile
.
ldgsts
(
sts_stage_idx
);
#pragma unroll
for
(
int
k_phase_idx
=
0
;
k_phase_idx
<
2
;
k_phase_idx
++
)
{
// dequantize next B tile
if
(
k_phase_idx
==
1
)
{
cp_asyc_wait_group
<
NStage
-
2
>
();
__syncthreads
();
lds_stage_idx
=
lds_stage_idx
<
NStage
-
1
?
lds_stage_idx
+
1
:
0
;
}
compute_tile
.
lds
(
lds_stage_idx
,
reg_buf_idx
,
(
k_phase_idx
+
1
)
%
2
);
compute_tile
.
mma
(
reg_buf_idx
^
1
);
reg_buf_idx
^=
1
;
}
}
// last NStage-1 tiles
for
(;
k_tiles
>
0
;
--
k_tiles
)
{
cp_async_commit_group
();
#pragma unroll
for
(
int
k_phase_idx
=
0
;
k_phase_idx
<
2
;
k_phase_idx
++
)
{
// dequantize next B tile
if
(
k_phase_idx
==
1
)
{
cp_asyc_wait_group
<
NStage
-
2
>
();
__syncthreads
();
lds_stage_idx
=
lds_stage_idx
<
NStage
-
1
?
lds_stage_idx
+
1
:
0
;
}
compute_tile
.
lds
(
lds_stage_idx
,
reg_buf_idx
,
(
k_phase_idx
+
1
)
%
2
);
compute_tile
.
mma
(
reg_buf_idx
^
1
);
reg_buf_idx
^=
1
;
}
}
if
(
EnableFuse
)
{
compute_tile
.
fused_splitk_reduce
();
}
compute_tile
.
stg
(
smem
);
}
#define __CALL_IF(MTILE, NTILE, NUM_THREADS, ENABLE_FUSE, HAS_ZP) \
else if (Mtile == MTILE && Ntile == NTILE && BLOCK == NUM_THREADS && \
enable_fuse == ENABLE_FUSE && has_zp == HAS_ZP) { \
ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_hmma16816_multistage_AN_BTN32K16_CN_splitk_kernel< \
FType, QType, MTILE, NTILE, 4, NUM_THREADS, ENABLE_FUSE, HAS_ZP> \
<<<grid, block, 0, stream>>>(params); \
}
template
<
typename
FType
,
typename
QType
>
void
ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32K16_CN_splitk
(
const
FType
*
A
,
const
QType
*
B
,
const
FType
*
B_scale
,
const
FType
*
B_zero
,
FType
*
C
,
const
int
M
,
const
int
N
,
const
int
K
,
void
*
workspace
,
const
int
sm_version
,
const
BlockTileSplitkParams
&
fused_gemm_params
,
cudaStream_t
stream
)
{
int
Mtile
=
fused_gemm_params
.
Mtile
;
int
grid_x
=
(
M
+
Mtile
-
1
)
/
Mtile
;
int
Ntile
=
fused_gemm_params
.
Ntile
;
int
grid_y
=
(
N
+
Ntile
-
1
)
/
Ntile
;
int
SplitK
=
fused_gemm_params
.
SplitK
;
int
grid_z
=
(
K
+
SplitK
-
1
)
/
SplitK
;
int
BLOCK
=
(
Ntile
==
256
)
?
256
:
128
;
dim3
grid
(
grid_x
,
grid_y
,
grid_z
);
dim3
block
(
BLOCK
);
bool
enable_fuse
=
fused_gemm_params
.
EnableFuse
;
bool
has_zp
=
B_zero
!=
nullptr
;
if
(
enable_fuse
)
{
float
*
C_tmp
=
reinterpret_cast
<
float
*>
(
workspace
);
uint32_t
*
red_count
=
reinterpret_cast
<
uint32_t
*>
(
(
char
*
)
workspace
+
grid_x
*
Mtile
*
grid_y
*
Ntile
*
sizeof
(
float
));
CHECK_CUDA
(
cudaMemsetAsync
(
red_count
,
0
,
grid_x
*
grid_y
*
sizeof
(
uint32_t
),
stream
));
SM8x_GEMM_W8A16_Splitk_Params
<
FType
,
QType
>
params
{
A
,
B
,
B_scale
,
B_zero
,
C
,
M
,
N
,
K
,
SplitK
,
0
,
-
1
,
nullptr
,
C_tmp
,
red_count
};
if
(
false
)
{
}
// Select the template parameters for kernel launch
// according to the above settings. Tuning is not supported.
__CALL_IF
(
16
,
256
,
256
,
true
,
false
)
__CALL_IF
(
32
,
256
,
256
,
true
,
false
)
__CALL_IF
(
48
,
256
,
256
,
true
,
false
)
__CALL_IF
(
64
,
128
,
128
,
true
,
false
)
__CALL_IF
(
64
,
256
,
256
,
true
,
false
)
__CALL_IF
(
16
,
256
,
256
,
true
,
true
)
__CALL_IF
(
32
,
256
,
256
,
true
,
true
)
__CALL_IF
(
48
,
256
,
256
,
true
,
true
)
__CALL_IF
(
64
,
128
,
128
,
true
,
true
)
__CALL_IF
(
64
,
256
,
256
,
true
,
true
)
}
else
{
FType
*
C_split
=
reinterpret_cast
<
FType
*>
(
workspace
);
SM8x_GEMM_W8A16_Splitk_Params
<
FType
,
QType
>
params
{
A
,
B
,
B_scale
,
B_zero
,
C
,
M
,
N
,
K
,
SplitK
,
0
,
-
1
,
C_split
,
nullptr
,
nullptr
};
if
(
false
)
{
}
// Select the template parameters for kernel launch
// according to the above settings. Tuning is not supported.
__CALL_IF
(
16
,
256
,
256
,
false
,
false
)
__CALL_IF
(
32
,
256
,
256
,
false
,
false
)
__CALL_IF
(
48
,
256
,
256
,
false
,
false
)
__CALL_IF
(
64
,
128
,
128
,
false
,
false
)
__CALL_IF
(
64
,
256
,
256
,
false
,
false
)
__CALL_IF
(
16
,
256
,
256
,
false
,
true
)
__CALL_IF
(
32
,
256
,
256
,
false
,
true
)
__CALL_IF
(
48
,
256
,
256
,
false
,
true
)
__CALL_IF
(
64
,
128
,
128
,
false
,
true
)
__CALL_IF
(
64
,
256
,
256
,
false
,
true
)
// SplitK reduce
f16_gemm_splitk_reduce
(
C_split
,
C
,
M
,
N
,
grid_z
,
stream
);
}
}
size_t
allspark_qgemm_w8a16_perc_n32k16_ampere_workspace_size
(
int
m
,
int
n
,
int
k
,
int
sm_count
,
BlockTileSplitkParams
&
fused_gemm_params
)
{
// Determine the block tile and splitk strategy
int
m16_times
=
(
m
+
16
-
1
)
/
16
;
int
Mtile
=
m16_times
<=
4
?
m16_times
*
16
:
64
;
int
grid_x
=
(
m
+
Mtile
-
1
)
/
Mtile
;
int
Ntile
=
(
float
(
grid_x
*
((
n
+
127
)
/
128
))
/
sm_count
>
10
)
||
(
Mtile
<
64
)
?
256
:
128
;
int
grid_y
=
(
n
+
Ntile
-
1
)
/
Ntile
;
int
grid_z
;
// split-k
const
float
SPLIT_THRESHOLD
=
0.8
;
int
n_slice
;
for
(
n_slice
=
1
;
n_slice
<
k
/
256
;
++
n_slice
)
{
int
n_block
=
grid_x
*
grid_y
*
n_slice
;
if
(
n_block
>=
sm_count
*
SPLIT_THRESHOLD
&&
(
n_block
%
sm_count
==
0
||
n_block
%
sm_count
>=
sm_count
*
0.5
))
{
break
;
}
}
int
k_slice
=
(
k
/
n_slice
)
%
32
==
0
?
k
/
n_slice
:
k
/
n_slice
/
32
*
32
+
32
;
grid_z
=
(
k
+
k_slice
-
1
)
/
k_slice
;
bool
enable_fuse
=
float
(
grid_x
*
grid_y
)
/
sm_count
>=
0.5
?
1
:
0
;
size_t
ws_size
;
if
(
enable_fuse
)
{
ws_size
=
grid_x
*
Mtile
*
grid_y
*
Ntile
*
sizeof
(
float
)
// For C_tmp
+
grid_x
*
grid_y
*
sizeof
(
uint32_t
);
// For red_count
}
else
{
ws_size
=
grid_z
*
m
*
n
*
sizeof
(
__half
);
}
fused_gemm_params
.
Mtile
=
Mtile
;
fused_gemm_params
.
Ntile
=
Ntile
;
fused_gemm_params
.
SplitK
=
k_slice
;
fused_gemm_params
.
EnableFuse
=
enable_fuse
;
return
ws_size
;
}
// restore from N32K16 order to original N-major order
// K % 16 == 0, N % 8 == 0
// each block process 64(k) * 32(n) result elements
template
<
typename
FT
,
typename
QT
>
__global__
void
restore_N32_K16_dequantize_rhs_w8a16_perc_kernel
(
const
QT
*
qdata
,
const
FT
*
scales
,
const
FT
*
zeros
,
FT
*
fdata
,
const
int
N_32align
,
const
int
N
,
const
int
K
)
{
__shared__
FT
smem
[
64
*
32
];
int
warp_id
=
threadIdx
.
x
/
32
;
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
src_row_idx
=
blockIdx
.
x
*
8
+
lane_id
/
4
;
const
int
src_col_idx
=
blockIdx
.
y
*
64
*
4
+
warp_id
*
16
*
4
+
(
lane_id
%
4
)
*
16
;
const
int
src_offset
=
src_row_idx
*
K
*
4
+
src_col_idx
;
int
params_nidx
=
blockIdx
.
x
*
32
+
(
lane_id
/
4
)
*
4
;
QT
qval_reg
[
16
];
const
QT
*
pdata
=
qdata
+
src_offset
;
if
(
src_col_idx
<
(
K
*
4
))
{
*
(
reinterpret_cast
<
uint4
*>
(
qval_reg
))
=
*
(
reinterpret_cast
<
const
uint4
*>
(
qdata
+
src_offset
));
}
FT
scale_reg
[
4
];
*
(
reinterpret_cast
<
uint2
*>
(
scale_reg
))
=
*
(
reinterpret_cast
<
const
uint2
*>
(
scales
+
params_nidx
));
FT
zero_reg
[
4
]
=
{
0
};
if
(
zeros
!=
nullptr
)
{
*
(
reinterpret_cast
<
uint2
*>
(
zero_reg
))
=
*
(
reinterpret_cast
<
const
uint2
*>
(
zeros
+
params_nidx
));
}
FT
fval_reg
[
16
];
const
int
sts_base_offset
=
(
warp_id
*
16
+
(
lane_id
%
4
)
*
2
)
*
32
+
lane_id
/
4
;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
4
;
++
ni
)
{
cvt_8bx4_to_16bx4_bias128
(
*
reinterpret_cast
<
uint32_t
*>
(
&
qval_reg
[
ni
*
4
]),
reinterpret_cast
<
typename
HalfType
<
FT
>::
T2
*>
(
&
(
fval_reg
[
ni
*
4
])));
#pragma unroll
for
(
int
ki
=
0
;
ki
<
4
;
++
ki
)
{
fval_reg
[
ni
*
4
+
ki
]
=
(
fval_reg
[
ni
*
4
+
ki
]
-
zero_reg
[
ni
])
*
scale_reg
[
ni
];
int
sts_offset
=
sts_base_offset
+
((
ki
/
2
)
*
8
+
(
ki
%
2
))
*
32
+
((
ni
+
lane_id
%
4
)
%
4
)
*
8
;
smem
[
sts_offset
]
=
fval_reg
[
ni
*
4
+
ki
];
}
}
__syncthreads
();
const
int
lds_base_offset
=
(
threadIdx
.
x
/
4
)
*
32
+
((
threadIdx
.
x
%
4
+
threadIdx
.
x
/
8
)
%
4
)
*
8
;
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
*
reinterpret_cast
<
uint4
*>
(
fval_reg
+
i
*
8
)
=
*
reinterpret_cast
<
uint4
*>
(
smem
+
lds_base_offset
+
i
*
32
*
32
);
}
const
int
dst_row_base_kidx
=
blockIdx
.
y
*
64
+
threadIdx
.
x
/
4
;
const
int
dst_col_nidx
=
blockIdx
.
x
*
32
+
(
threadIdx
.
x
%
4
)
*
8
;
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
int
dst_row_kidx
=
dst_row_base_kidx
+
i
*
32
;
int
dst_offset
=
dst_row_kidx
*
N
+
dst_col_nidx
;
if
(
dst_row_kidx
<
K
&&
dst_col_nidx
<
N
)
{
*
reinterpret_cast
<
uint4
*>
(
fdata
+
dst_offset
)
=
*
reinterpret_cast
<
uint4
*>
(
fval_reg
+
i
*
8
);
}
}
}
template
<
typename
FT
,
typename
QT
>
void
restore_N32_K16_dequantize_rhs_w8a16
(
const
QT
*
qdata
,
const
FT
*
scales
,
const
FT
*
zeros
,
FT
*
fdata
,
const
int
N_32align
,
const
int
N
,
const
int
K
,
const
int
GroupSize
,
cudaStream_t
stream
)
{
TORCH_CHECK
(
N
%
8
==
0
&&
K
%
16
==
0
&&
N_32align
%
32
==
0
,
"Unsupported shape"
);
if
(
GroupSize
==
-
1
)
{
const
int
BLOCK
=
128
;
dim3
grid
(
N_32align
/
32
,
((
K
/
16
)
+
3
)
/
4
);
restore_N32_K16_dequantize_rhs_w8a16_perc_kernel
<
FT
,
QT
>
<<<
grid
,
BLOCK
,
0
,
stream
>>>
(
qdata
,
scales
,
zeros
,
fdata
,
N_32align
,
N
,
K
);
}
// TODO: Support SubChannel
else
{
TORCH_CHECK
(
false
,
"Now only support PerChannel"
);
}
}
template
<
typename
FT
,
typename
QT
>
void
w8a16_gemm_dq_cublas
(
const
FT
*
in
,
const
QT
*
rhs_qdata_ptr
,
const
FT
*
rhs_scales_ptr
,
const
FT
*
rhs_zeros_ptr
,
FT
*
out
,
void
*
workspace
,
const
int
M
,
const
int
N_32align
,
const
int
N
,
const
int
K
,
const
int
group_size
,
cudaStream_t
stream
,
cublasHandle_t
handle
)
{
static_assert
(
std
::
is_same
<
FT
,
half
>::
value
||
std
::
is_same
<
FT
,
nv_bfloat16
>::
value
,
"only float16 and bfloat16 is supported"
);
// Dequant
FT
*
rhs_fdata_ptr
=
static_cast
<
FT
*>
(
workspace
);
restore_N32_K16_dequantize_rhs_w8a16
(
rhs_qdata_ptr
,
rhs_scales_ptr
,
rhs_zeros_ptr
,
rhs_fdata_ptr
,
N_32align
,
N
,
K
,
group_size
,
stream
);
// cuBLAS GEMM
int
lda
=
K
;
int
ldb
=
N
;
int
ldc
=
N
;
const
float
alpha
=
1.0
f
;
const
float
beta
=
0.0
f
;
cudaDataType_t
cuda_type
;
if
(
std
::
is_same
<
FT
,
__half
>::
value
)
{
cuda_type
=
CUDA_R_16F
;
}
else
{
cuda_type
=
CUDA_R_16BF
;
}
CHECK_CUBLAS
(
cublasGemmEx
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
N
,
M
,
K
,
&
alpha
,
rhs_fdata_ptr
,
cuda_type
,
ldb
,
in
,
cuda_type
,
lda
,
&
beta
,
out
,
cuda_type
,
ldc
,
CUDA_R_32F
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
}
template
<
typename
FType
,
typename
QType
>
void
allspark_qgemm_w8a16_perc_ampere
(
const
FType
*
A
,
const
QType
*
B
,
const
FType
*
B_scale
,
const
FType
*
B_zero
,
FType
*
C
,
const
int
M
,
const
int
N_32align
,
const
int
N
,
const
int
K
,
void
*
workspace
,
const
BlockTileSplitkParams
&
fused_gemm_params
,
const
int
group_size
,
int
CUBLAS_M_THRESHOLD
,
const
int
sm_version
,
cudaStream_t
stream
,
cublasHandle_t
handle
)
{
if
(
M
>
CUBLAS_M_THRESHOLD
)
{
w8a16_gemm_dq_cublas
<
FType
,
QType
>
(
A
,
B
,
B_scale
,
B_zero
,
C
,
workspace
,
M
,
N_32align
,
N
,
K
,
group_size
,
stream
,
handle
);
}
else
{
ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32K16_CN_splitk
<
FType
,
QType
>
(
A
,
B
,
B_scale
,
B_zero
,
C
,
M
,
N
,
K
,
workspace
,
sm_version
,
fused_gemm_params
,
stream
);
}
}
}
// namespace allspark
torch
::
Tensor
allspark_w8a16_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_qweight
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
b_qzeros
,
int64_t
n
,
int64_t
group_size
,
int64_t
sm_count
,
int64_t
sm_version
,
int64_t
CUBLAS_M_THRESHOLD
,
bool
has_zp
,
bool
n32k16_reorder
)
{
// Verify device and strides
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
is_contiguous
(),
"A is not contiguous"
);
TORCH_CHECK
(
b_qweight
.
device
().
is_cuda
(),
"b_qweight is not on GPU"
);
TORCH_CHECK
(
b_qweight
.
is_contiguous
(),
"b_qweight is not contiguous"
);
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
if
(
has_zp
)
{
TORCH_CHECK
(
b_qzeros
.
value
().
device
().
is_cuda
(),
"b_qzeros is not on GPU"
);
TORCH_CHECK
(
b_qzeros
.
value
().
is_contiguous
(),
"b_qzeros is not contiguous"
);
}
int
m
=
a
.
size
(
0
);
int
n_32align
=
(
n
+
32
-
1
)
/
32
*
32
;
int
k
=
a
.
size
(
1
);
// Verify shape
TORCH_CHECK
(
b_qweight
.
size
(
0
)
==
n_32align
,
"Shape mismatch: b_qweight.size(0) = "
,
b_qweight
.
size
(
0
),
", n_32align = "
,
n_32align
);
TORCH_CHECK
(
b_qweight
.
size
(
1
)
==
k
,
"Shape mismatch: b_qweight.size(1) = "
,
b_qweight
.
size
(
1
),
", k = "
,
k
);
TORCH_CHECK
(
group_size
==
-
1
,
"Currently only supports group_size = -1"
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
const
void
*
a_ptr
=
reinterpret_cast
<
const
void
*>
(
a
.
data_ptr
());
const
uint8_t
*
b_ptr
=
reinterpret_cast
<
const
uint8_t
*>
(
b_qweight
.
data_ptr
());
const
void
*
b_scale_ptr
=
reinterpret_cast
<
const
void
*>
(
b_scales
.
data_ptr
());
const
void
*
b_zero_ptr
=
nullptr
;
if
(
b_qzeros
.
has_value
())
{
b_zero_ptr
=
reinterpret_cast
<
const
void
*>
(
b_qzeros
.
value
().
data_ptr
());
}
auto
c_options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
empty
({
m
,
n
},
c_options
);
void
*
c_ptr
=
reinterpret_cast
<
void
*>
(
c
.
data_ptr
());
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
cublasHandle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
();
allspark
::
BlockTileSplitkParams
fused_gemm_params
;
size_t
ws_size
=
0
;
if
(
m
>
CUBLAS_M_THRESHOLD
)
{
ws_size
=
k
*
n
*
2
;
// sizeof(f16)==2
}
else
{
ws_size
=
allspark
::
allspark_qgemm_w8a16_perc_n32k16_ampere_workspace_size
(
m
,
n
,
k
,
sm_count
,
fused_gemm_params
);
}
auto
ws_options
=
torch
::
TensorOptions
().
dtype
(
at
::
kChar
).
device
(
a
.
device
());
if
(
as_g_workspace
.
numel
()
<
ws_size
)
{
// ws_options: kChar, so numel() is bytes
as_g_workspace
=
torch
::
empty
({
long
(
ws_size
)},
ws_options
);
}
void
*
ws
=
reinterpret_cast
<
void
*>
(
as_g_workspace
.
data_ptr
());
if
(
a
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
allspark
::
allspark_qgemm_w8a16_perc_ampere
<
__half
,
uint8_t
>
(
reinterpret_cast
<
const
__half
*>
(
a_ptr
),
b_ptr
,
reinterpret_cast
<
const
__half
*>
(
b_scale_ptr
),
reinterpret_cast
<
const
__half
*>
(
b_zero_ptr
),
reinterpret_cast
<
__half
*>
(
c_ptr
),
m
,
n_32align
,
n
,
k
,
ws
,
fused_gemm_params
,
group_size
,
CUBLAS_M_THRESHOLD
,
sm_version
,
stream
,
handle
);
}
else
if
(
a
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
allspark
::
allspark_qgemm_w8a16_perc_ampere
<
__nv_bfloat16
,
uint8_t
>
(
reinterpret_cast
<
const
__nv_bfloat16
*>
(
a_ptr
),
b_ptr
,
reinterpret_cast
<
const
__nv_bfloat16
*>
(
b_scale_ptr
),
reinterpret_cast
<
const
__nv_bfloat16
*>
(
b_zero_ptr
),
reinterpret_cast
<
__nv_bfloat16
*>
(
c_ptr
),
m
,
n_32align
,
n
,
k
,
ws
,
fused_gemm_params
,
group_size
,
CUBLAS_M_THRESHOLD
,
sm_version
,
stream
,
handle
);
}
return
c
;
}
#endif
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"allspark_w8a16_gemm"
,
&
allspark_w8a16_gemm
);
}
\ No newline at end of file
csrc/quantization/gptq_allspark/allspark_repack.cu
0 → 100644
View file @
6a92ff93
#include "allspark_utils.cuh"
#include <torch/all.h>
#include "core/registration.h"
namespace
allspark
{
// Rearrange B to facilitate Ampere Tensor Core load data
// reorder B from (K, N) to (N_32align / 4, K * 4)
// K % 16 == 0, N % 16 == 0, N_32align % 32 == 0
template
<
typename
FType
>
__global__
void
__launch_bounds__
(
128
)
rearrange_kn_weight_as_n32k16_order_ldg16_kernel
(
const
uint8_t
*
B
,
const
FType
*
B_scale
,
const
FType
*
B_zero
,
uint8_t
*
B_result
,
FType
*
B_scale_result
,
FType
*
B_zero_result
,
const
int
K
,
const
int
N
,
const
int
N_32align
)
{
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
warp_id
=
threadIdx
.
x
/
32
;
if
(
blockIdx
.
x
!=
gridDim
.
x
-
1
)
{
// Load B
// per block process 64(k) * 128(n) B elements
// per warp process 16(k) * 128 B elements
const
int
src_row_base_idx
=
blockIdx
.
x
*
64
+
warp_id
*
16
+
((
lane_id
%
8
)
/
2
)
*
2
;
const
int
src_col_idx
=
blockIdx
.
y
*
128
+
(
lane_id
/
8
)
*
32
+
(
lane_id
%
2
)
*
16
;
uint8_t
B_frag
[
4
][
16
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
int
src_row_idx
=
src_row_base_idx
+
(
i
/
2
)
*
8
+
(
i
%
2
);
int
src_offset
=
src_row_idx
*
N
+
src_col_idx
;
bool
guard
=
src_row_idx
<
K
&&
src_col_idx
<
N
;
ldg128_cg_0
(
*
reinterpret_cast
<
uint32_t
*>
(
B_frag
[
i
]),
*
(
reinterpret_cast
<
uint32_t
*>
(
B_frag
[
i
])
+
1
),
*
(
reinterpret_cast
<
uint32_t
*>
(
B_frag
[
i
])
+
2
),
*
(
reinterpret_cast
<
uint32_t
*>
(
B_frag
[
i
])
+
3
),
B
+
src_offset
,
guard
);
}
// reorder B
uint8_t
B_reorder_frag
[
8
][
8
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
16
;
++
j
)
{
int
dst_i
=
j
%
8
;
int
dst_j
=
i
+
(
j
/
8
)
*
4
;
B_reorder_frag
[
dst_i
][
dst_j
]
=
B_frag
[
i
][
j
];
}
}
// Store B
const
int
dst_row_base_idx
=
blockIdx
.
y
*
(
128
/
4
)
+
(
lane_id
/
8
)
*
8
;
const
int
dst_col_idx
=
blockIdx
.
x
*
(
64
*
4
)
+
warp_id
*
64
+
(
lane_id
%
8
)
*
8
;
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
int
dst_row_idx
=
dst_row_base_idx
+
i
;
int
dst_offset
=
dst_row_idx
*
K
*
4
+
dst_col_idx
;
bool
guard
=
(
dst_row_base_idx
<
N_32align
/
4
)
&&
(
dst_col_idx
<
K
*
4
);
if
(
guard
)
{
*
reinterpret_cast
<
int2
*>
(
B_result
+
dst_offset
)
=
*
reinterpret_cast
<
int2
*>
(
B_reorder_frag
[
i
]);
}
}
}
else
{
// Load B_scale and B_zero
FType
b_scale_reg
,
b_zero_reg
;
int
src_offset
=
blockIdx
.
y
*
128
+
threadIdx
.
x
;
ldg16_cg_0
(
b_scale_reg
,
B_scale
+
src_offset
,
src_offset
<
N
);
if
(
B_zero
!=
nullptr
)
ldg16_cg_0
(
b_zero_reg
,
B_zero
+
src_offset
,
src_offset
<
N
);
int
dst_offset
=
blockIdx
.
y
*
128
+
warp_id
*
32
+
(
lane_id
%
8
)
*
4
+
lane_id
/
8
;
if
(
dst_offset
<
N_32align
)
{
B_scale_result
[
dst_offset
]
=
b_scale_reg
;
if
(
B_zero
!=
nullptr
)
B_zero_result
[
dst_offset
]
=
b_zero_reg
;
}
}
}
template
<
typename
FType
>
void
rearrange_kn_weight_as_n32k16_order_ldg16
(
const
uint8_t
*
B
,
const
FType
*
B_scale
,
const
FType
*
B_zero
,
uint8_t
*
B_result
,
FType
*
B_scale_result
,
FType
*
B_zero_result
,
const
int64_t
K
,
const
int64_t
N
,
const
int64_t
N_32align
,
cudaStream_t
stream
)
{
if
(
N
%
16
!=
0
||
K
%
16
!=
0
)
{
std
::
cerr
<<
"Now only support N and K is multiples of 16"
<<
std
::
endl
;
}
const
int
BLOCK
=
128
;
int
grid_x
=
(
K
+
64
-
1
)
/
64
+
1
;
int
grid_y
=
(
N
+
128
-
1
)
/
128
;
dim3
grid
(
grid_x
,
grid_y
);
rearrange_kn_weight_as_n32k16_order_ldg16_kernel
<
FType
>
<<<
grid
,
BLOCK
,
0
,
stream
>>>
(
B
,
B_scale
,
B_zero
,
B_result
,
B_scale_result
,
B_zero_result
,
K
,
N
,
N_32align
);
}
}
// namespace allspark
void
rearrange_kn_weight_as_n32k16_order
(
torch
::
Tensor
const
&
b_qweight
,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
b_zeros
,
bool
has_zp
,
torch
::
Tensor
&
b_qweight_reorder
,
torch
::
Tensor
&
b_scales_reorder
,
c10
::
optional
<
torch
::
Tensor
>
const
&
b_zeros_reorder
,
const
int64_t
K
,
const
int64_t
N
,
const
int64_t
N_32align
)
{
// Verify device and strides
TORCH_CHECK
(
b_qweight
.
device
().
is_cuda
(),
"b_qweight is not on GPU"
);
TORCH_CHECK
(
b_qweight
.
is_contiguous
(),
"b_qweight is not contiguous"
);
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_qweight_reorder
.
device
().
is_cuda
(),
"b_qweight_reorder is not on GPU"
);
TORCH_CHECK
(
b_qweight_reorder
.
is_contiguous
(),
"b_qweight_reorder is not contiguous"
);
TORCH_CHECK
(
b_scales_reorder
.
device
().
is_cuda
(),
"b_scales_reorder is not on GPU"
);
TORCH_CHECK
(
b_scales_reorder
.
is_contiguous
(),
"b_scales_reorder is not contiguous"
);
if
(
has_zp
)
{
TORCH_CHECK
(
b_zeros
.
value
().
device
().
is_cuda
(),
"b_zeros is not on GPU"
);
TORCH_CHECK
(
b_zeros
.
value
().
is_contiguous
(),
"b_zeros is not contiguous"
);
TORCH_CHECK
(
b_zeros_reorder
.
value
().
device
().
is_cuda
(),
"b_zeros_reorder is not on GPU"
);
TORCH_CHECK
(
b_zeros_reorder
.
value
().
is_contiguous
(),
"b_zeros_reorder is not contiguous"
);
}
const
uint8_t
*
matB
=
reinterpret_cast
<
const
uint8_t
*>
(
b_qweight
.
data_ptr
());
const
void
*
b_scale
=
b_scales
.
data_ptr
();
const
void
*
b_zero
=
has_zp
?
b_zeros
.
value
().
data_ptr
()
:
nullptr
;
uint8_t
*
matB_reorder
=
reinterpret_cast
<
uint8_t
*>
(
b_qweight_reorder
.
data_ptr
());
void
*
b_scale_reorder
=
b_scales_reorder
.
data_ptr
();
void
*
b_zero_reorder
=
has_zp
?
b_zeros_reorder
.
value
().
data_ptr
()
:
nullptr
;
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
b_scales
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
allspark
::
rearrange_kn_weight_as_n32k16_order_ldg16
<
__half
>
(
matB
,
reinterpret_cast
<
const
__half
*>
(
b_scale
),
reinterpret_cast
<
const
__half
*>
(
b_zero
),
matB_reorder
,
reinterpret_cast
<
__half
*>
(
b_scale_reorder
),
reinterpret_cast
<
__half
*>
(
b_zero_reorder
),
K
,
N
,
N_32align
,
stream
);
}
else
if
(
b_scales
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
allspark
::
rearrange_kn_weight_as_n32k16_order_ldg16
<
__nv_bfloat16
>
(
matB
,
reinterpret_cast
<
const
__nv_bfloat16
*>
(
b_scale
),
reinterpret_cast
<
const
__nv_bfloat16
*>
(
b_zero
),
matB_reorder
,
reinterpret_cast
<
__nv_bfloat16
*>
(
b_scale_reorder
),
reinterpret_cast
<
__nv_bfloat16
*>
(
b_zero_reorder
),
K
,
N
,
N_32align
,
stream
);
}
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"rearrange_kn_weight_as_n32k16_order"
,
&
rearrange_kn_weight_as_n32k16_order
);
}
csrc/quantization/gptq_allspark/allspark_utils.cuh
0 → 100644
View file @
6a92ff93
#pragma once
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <iostream>
namespace
allspark
{
#define CHECK_CUDA(cmd) \
do { \
cudaError_t cuda_status = cmd; \
if (cuda_status != cudaSuccess) { \
std::string err_str = cudaGetErrorString(cuda_status); \
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
<< err_str; \
exit(-1); \
} \
} while (0)
#define CHECK_CUBLAS(cmd) \
do { \
cublasStatus_t cublas_status = cmd; \
if (cublas_status != CUBLAS_STATUS_SUCCESS) { \
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
<< cublas_status << std::endl; \
exit(-1); \
} \
} while (0)
template
<
typename
FType
,
typename
QType
>
struct
SM8x_GEMM_W8A16_Splitk_Params
{
const
FType
*
A_ptr
;
const
QType
*
B_ptr
;
const
FType
*
B_scale_ptr
;
const
FType
*
B_zero_ptr
;
FType
*
C_ptr
;
int
M
;
int
N
;
int
K
;
int
SplitK
;
int
GroupCnt
;
int
GroupSize
;
FType
*
C_split_ptr
;
// for non-fused splitk reduce
float
*
C_tmp_ptr
;
// for fused splitk reduce
uint32_t
*
red_count_ptr
;
// for fused splitk reduce
};
struct
alignas
(
16
)
BlockTileSplitkParams
{
int
Mtile
;
int
Ntile
;
int
SplitK
;
bool
EnableFuse
;
};
template
<
typename
FType
,
int
BLOCK
,
int
N_MATRIX
>
__global__
void
f16_gemm_splitk_reduce_kernel
(
const
FType
*
C_split
,
FType
*
C
,
uint32_t
n
,
uint32_t
n_matrix
,
uint32_t
matrix_size
)
{
int
idx
=
blockIdx
.
x
*
BLOCK
+
threadIdx
.
x
;
if
(
idx
>=
matrix_size
)
{
return
;
}
FType
sum
(
0
);
int
n_mat
=
N_MATRIX
>
0
?
N_MATRIX
:
(
int
)
n_matrix
;
for
(
int
i
=
0
;
i
<
n_mat
;
++
i
)
{
sum
+=
C_split
[
idx
+
i
*
matrix_size
];
}
C
[
idx
]
=
sum
;
}
template
<
typename
FType
>
void
f16_gemm_splitk_reduce
(
const
FType
*
C_split
,
FType
*
C
,
const
uint32_t
m
,
const
uint32_t
n
,
const
uint32_t
n_matrix
,
cudaStream_t
stream
)
{
const
int
BLOCK
=
128
;
uint32_t
matrix_size
=
m
*
n
;
int
grid
=
(
matrix_size
+
BLOCK
-
1
)
/
BLOCK
;
void
(
*
kernel
)(
const
FType
*
,
FType
*
,
uint32_t
,
uint32_t
,
uint32_t
)
=
nullptr
;
switch
(
n_matrix
)
{
case
4
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
4
>
;
break
;
case
5
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
5
>
;
break
;
case
6
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
6
>
;
break
;
case
7
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
7
>
;
break
;
case
8
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
8
>
;
break
;
case
9
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
9
>
;
break
;
case
10
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
10
>
;
break
;
case
11
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
11
>
;
break
;
case
12
:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
12
>
;
break
;
default:
kernel
=
f16_gemm_splitk_reduce_kernel
<
FType
,
BLOCK
,
-
1
>
;
break
;
}
kernel
<<<
grid
,
BLOCK
,
0
,
stream
>>>
(
C_split
,
C
,
n
,
n_matrix
,
matrix_size
);
}
template
<
typename
T
>
struct
HalfType
;
template
<
>
struct
HalfType
<
half
>
{
using
T1
=
__half
;
using
T2
=
__half2
;
};
template
<
>
struct
HalfType
<
__nv_bfloat16
>
{
using
T1
=
__nv_bfloat16
;
using
T2
=
__nv_bfloat162
;
};
// convert 64-bit pointer to 32-bit smem addr
__device__
__forceinline__
uint32_t
smem_u32addr
(
const
void
*
smem_ptr
)
{
uint32_t
addr
;
asm
(
"{.reg .u64 u64addr;
\n
"
" cvta.to.shared.u64 u64addr, %1;
\n
"
" cvt.u32.u64 %0, u64addr;}
\n
"
:
"=r"
(
addr
)
:
"l"
(
smem_ptr
));
return
addr
;
}
template
<
typename
T
>
__device__
__forceinline__
void
ldg16_cg_0
(
T
&
r0
,
const
void
*
ptr
,
bool
guard
)
{
static_assert
(
sizeof
(
T
)
==
2
,
"ldg16_cg_0: invalid T"
);
asm
volatile
(
"{.reg .pred p;
\n
"
" setp.ne.b32 p, %2, 0;
\n
"
" @!p mov.b16 %0, 0;
\n
"
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
__CUDA_ARCH__ >= 750
" @p ld.global.cg.L2::128B.b16 {%0}, [%1];}
\n
"
#else
" @p ld.global.ca.b16 {%0}, [%1];}
\n
"
#endif
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
r0
))
:
"l"
(
ptr
),
"r"
((
int
)
guard
));
}
template
<
typename
T
>
__device__
__forceinline__
void
ldg64_ca
(
T
&
r0
,
T
&
r1
,
const
void
*
ptr
,
bool
guard
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"ldg64_ca: invalid T"
);
asm
volatile
(
"{.reg .pred p;
\n
"
" setp.ne.b32 p, %3, 0;
\n
"
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
__CUDA_ARCH__ >= 750
" @p ld.global.ca.L2::128B.v2.b32 {%0, %1}, [%2];}
\n
"
#else
" @p ld.global.ca.v2.b32 {%0, %1}, [%2];}
\n
"
#endif
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r0
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r1
))
:
"l"
(
ptr
),
"r"
((
int
)
guard
));
}
template
<
typename
T
>
__device__
__forceinline__
void
ldg128_cg_0
(
T
&
r0
,
T
&
r1
,
T
&
r2
,
T
&
r3
,
const
void
*
ptr
,
bool
guard
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"ldg128_cg_0: invalid T"
);
asm
volatile
(
"{.reg .pred p;
\n
"
" setp.ne.b32 p, %5, 0;
\n
"
" @!p mov.b32 %0, 0;
\n
"
" @!p mov.b32 %1, 0;
\n
"
" @!p mov.b32 %2, 0;
\n
"
" @!p mov.b32 %3, 0;
\n
"
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
__CUDA_ARCH__ >= 750
" @p ld.global.cg.L2::128B.v4.b32 {%0, %1, %2, %3}, [%4];}
\n
"
#else
" @p ld.global.cg.v4.b32 {%0, %1, %2, %3}, [%4];}
\n
"
#endif
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r0
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r1
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r2
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r3
))
:
"l"
(
ptr
),
"r"
((
int
)
guard
));
}
template
<
typename
T
>
__device__
__forceinline__
void
lds128
(
T
&
reg0
,
T
&
reg1
,
T
&
reg2
,
T
&
reg3
,
const
uint32_t
addr
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"lds128: invalid T"
);
asm
volatile
(
"ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
reg0
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
reg1
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
reg2
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
reg3
))
:
"r"
(
addr
));
}
template
<
typename
T
>
__device__
__forceinline__
void
stg128
(
const
T
&
r0
,
const
T
&
r1
,
const
T
&
r2
,
const
T
&
r3
,
const
void
*
ptr
,
bool
guard
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"stg128: invalid T"
);
asm
volatile
(
"{.reg .pred p;
\n
"
" setp.ne.b32 p, %1, 0;
\n
"
" @p st.global.v4.b32 [%0], {%2, %3, %4, %5};}
\n
"
:
:
"l"
(
ptr
),
"r"
((
int
)
guard
),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
r0
)),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
r1
)),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
r2
)),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
r3
)));
}
template
<
typename
T
>
__device__
__forceinline__
void
ldsm_4
(
T
&
r0
,
T
&
r1
,
T
&
r2
,
T
&
r3
,
const
uint32_t
&
addr
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"ldsm_4: invalid T"
);
#if (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 11)
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r0
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r1
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r2
)),
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
r3
))
:
"r"
(
addr
));
#endif
}
template
<
typename
FType
>
__device__
__forceinline__
void
hmma16816_f32
(
float
(
&
d
)[
4
],
const
uint32_t
(
&
a
)[
4
],
const
uint32_t
(
&
b
)[
2
]);
template
<
>
__device__
__forceinline__
void
hmma16816_f32
<
__half
>
(
float
(
&
d
)[
4
],
const
uint32_t
(
&
a
)[
4
],
const
uint32_t
(
&
b
)[
2
])
{
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};
\n
"
:
"+f"
(
d
[
0
]),
"+f"
(
d
[
1
]),
"+f"
(
d
[
2
]),
"+f"
(
d
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]));
#endif
}
template
<
>
__device__
__forceinline__
void
hmma16816_f32
<
__nv_bfloat16
>
(
float
(
&
d
)[
4
],
const
uint32_t
(
&
a
)[
4
],
const
uint32_t
(
&
b
)[
2
])
{
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};
\n
"
:
"+f"
(
d
[
0
]),
"+f"
(
d
[
1
]),
"+f"
(
d
[
2
]),
"+f"
(
d
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]));
#endif
}
template
<
int
SIZE_IN_BYTES
>
__device__
__forceinline__
void
cp_async
(
const
uint32_t
smem_addr
,
const
void
*
gmem_ptr
,
const
int
src_in_bytes
,
bool
guard
)
{
static_assert
(
(
SIZE_IN_BYTES
==
4
||
SIZE_IN_BYTES
==
8
||
SIZE_IN_BYTES
==
16
),
"Size is not supported"
);
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm
volatile
(
"{.reg.pred p;
\n
"
" setp.ne.b32 p, %4, 0;
\n
"
#if __CUDACC_VER_MINOR__ >= 4
" @p cp.async.cg.shared.global.L2::256B [%0], [%1], %2, %3;}
\n
"
#else
" @p cp.async.cg.shared.global [%0], [%1], %2, %3;}
\n
"
#endif
::
"r"
(
smem_addr
),
"l"
(
gmem_ptr
),
"n"
(
SIZE_IN_BYTES
),
"r"
(
src_in_bytes
),
"r"
((
int
)
guard
));
#endif
}
template
<
int
SIZE_IN_BYTES
>
__device__
__forceinline__
void
cp_async_ca
(
const
uint32_t
smem_addr
,
const
void
*
gmem_ptr
,
const
int
src_in_bytes
,
bool
guard
)
{
static_assert
(
(
SIZE_IN_BYTES
==
4
||
SIZE_IN_BYTES
==
8
||
SIZE_IN_BYTES
==
16
),
"Size is not supported"
);
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm
volatile
(
"{.reg.pred p;
\n
"
" setp.ne.b32 p, %4, 0;
\n
"
#if __CUDACC_VER_MINOR__ >= 4
" @p cp.async.ca.shared.global.L2::256B [%0], [%1], %2, %3;}
\n
"
#else
" @p cp.async.ca.shared.global [%0], [%1], %2, %3;}
\n
"
#endif
::
"r"
(
smem_addr
),
"l"
(
gmem_ptr
),
"n"
(
SIZE_IN_BYTES
),
"r"
(
src_in_bytes
),
"r"
((
int
)
guard
));
#endif
}
__device__
__forceinline__
void
cp_async_commit_group
()
{
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm
volatile
(
"cp.async.commit_group;
\n
"
);
#endif
}
template
<
int
N
>
__device__
__forceinline__
void
cp_asyc_wait_group
()
{
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
asm
volatile
(
"cp.async.wait_group %0;
\n
"
:
:
"n"
(
N
));
#endif
}
template
<
typename
T
>
__device__
__forceinline__
void
cvt_8bx4_to_16bx4_bias128
(
const
uint32_t
&
idata
,
T
*
fdata
);
template
<
>
// fast conversion: 4xuint8 to 4xhalf, subtracting bias = 128
__device__
__forceinline__
void
cvt_8bx4_to_16bx4_bias128
<
__half2
>
(
const
uint32_t
&
idata
,
__half2
*
fdata
)
{
uint32_t
i10
,
i32
;
asm
volatile
(
"prmt.b32 %0, %2, 0x64, 0x4140;"
"prmt.b32 %1, %2, 0x64, 0x4342;"
:
"=r"
(
i10
),
"=r"
(
i32
)
:
"r"
(
idata
));
static
constexpr
uint32_t
MAGIC_NUM
=
0x64806480
;
fdata
[
0
]
=
__hsub2
(
reinterpret_cast
<
const
__half2
&>
(
i10
),
reinterpret_cast
<
const
__half2
&>
(
MAGIC_NUM
));
fdata
[
1
]
=
__hsub2
(
reinterpret_cast
<
const
__half2
&>
(
i32
),
reinterpret_cast
<
const
__half2
&>
(
MAGIC_NUM
));
}
template
<
>
// fast conversion: 4xuint8 to 4xbfloat16, subtracting bias = 128
// reference from marlin fast implementation
__device__
__forceinline__
void
cvt_8bx4_to_16bx4_bias128
<
__nv_bfloat162
>
(
const
uint32_t
&
idata
,
__nv_bfloat162
*
fdata
)
{
float
fp32_imd
[
4
];
uint32_t
*
fp32_imd_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_imd
);
asm
volatile
(
"prmt.b32 %0, %4, 0x4B000000, 0x7650;"
"prmt.b32 %1, %4, 0x4B000000, 0x7651;"
"prmt.b32 %2, %4, 0x4B000000, 0x7652;"
"prmt.b32 %3, %4, 0x4B000000, 0x7653;"
:
"=r"
(
fp32_imd_casted
[
0
]),
"=r"
(
fp32_imd_casted
[
1
]),
"=r"
(
fp32_imd_casted
[
2
]),
"=r"
(
fp32_imd_casted
[
3
])
:
"r"
(
idata
));
fp32_imd
[
0
]
-=
8388736.
f
;
fp32_imd
[
1
]
-=
8388736.
f
;
fp32_imd
[
2
]
-=
8388736.
f
;
fp32_imd
[
3
]
-=
8388736.
f
;
uint32_t
*
bf16_res
=
reinterpret_cast
<
uint32_t
*>
(
fdata
);
asm
volatile
(
"prmt.b32 %0, %2, %3, 0x7632;"
"prmt.b32 %1, %4, %5, 0x7632;"
:
"=r"
(
bf16_res
[
0
]),
"=r"
(
bf16_res
[
1
])
:
"r"
(
fp32_imd_casted
[
0
]),
"r"
(
fp32_imd_casted
[
1
]),
"r"
(
fp32_imd_casted
[
2
]),
"r"
(
fp32_imd_casted
[
3
]));
}
static
__device__
nv_bfloat162
inline
num2num2
(
const
nv_bfloat16
x
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert
(
false
);
#else
return
__bfloat162bfloat162
(
x
);
#endif
__builtin_unreachable
();
// Suppress missing return statement warning
}
static
__device__
half2
inline
num2num2
(
const
half
x
)
{
return
__half2half2
(
x
);
}
}
// namespace allspark
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
6a92ff93
...
...
@@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor!? azp) -> ()"
);
ops
.
impl
(
"dynamic_scaled_int8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_int8_quant
);
#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops
.
def
(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
"Tensor? b_zeros, "
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
"Tensor!? b_zeros_reorder, "
"int K, int N, int N_32align) -> ()"
);
// conditionally compiled so impl in source file
// AllSpark quantization ops
ops
.
def
(
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
"Tensor? b_qzeros, "
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"
);
// conditionally compiled so impl in source file
#endif
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cache_ops
),
cache_ops
)
{
...
...
tests/kernels/test_allspark_gemm.py
0 → 100644
View file @
6a92ff93
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
ALLSPARK_AMPERE_K_ALIGN
,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
ALLSPARK_AMPERE_N_ALIGN
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
quantize_weights
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
def
is_gptq_allspark_supported
(
min_capability
:
int
,
max_capability
:
int
)
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
capability
=
current_platform
.
get_device_capability
()
assert
capability
is
not
None
return
capability
.
to_int
()
>=
min_capability
\
and
capability
.
to_int
()
<=
max_capability
MNK_FACTORS
=
[
(
1
,
4
,
8
),
(
13
,
17
,
67
),
(
26
,
37
,
13
),
(
48
,
16
,
24
),
(
67
,
13
,
88
),
(
257
,
13
,
11
),
(
658
,
13
,
11
),
(
1033
,
9
,
17
),
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
HAS_ZP_OPTS
=
[
False
,
True
]
def
compute_max_diff
(
output
,
output_ref
):
return
torch
.
mean
(
torch
.
abs
(
output
-
output_ref
))
/
torch
.
mean
(
torch
.
abs
(
output_ref
))
def
rand_data
(
shape
,
dtype
=
torch
.
float16
):
return
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
@
pytest
.
mark
.
skipif
(
not
is_gptq_allspark_supported
(
80
,
89
),
reason
=
"AllSpark Ampere kernel is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
])
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
HAS_ZP_OPTS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
def
test_gptq_allspark_gemm_ampere
(
mnk_factors
,
group_size
,
has_zp
,
dtype
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m
=
m_factor
n
=
n_factor
*
ALLSPARK_AMPERE_N_ALIGN
k
=
k_factor
*
ALLSPARK_AMPERE_K_ALIGN
input
=
rand_data
((
m
,
k
),
dtype
=
dtype
)
weight
=
rand_data
((
k
,
n
),
dtype
=
dtype
)
# Quantize (and apply act_order if provided)
w_ref
,
qw
,
s
,
zp
=
quantize_weights
(
weight
,
scalar_types
.
uint8b128
,
group_size
,
has_zp
)
qw
=
qw
.
to
(
torch
.
uint8
)
if
has_zp
:
zp
=
zp
.
to
(
dtype
)
properties
=
torch
.
cuda
.
get_device_properties
(
qw
.
device
.
index
)
sm_count
=
properties
.
multi_processor_count
sm_version
=
properties
.
major
*
10
+
properties
.
minor
n_32align
=
(
n
+
32
-
1
)
//
32
*
32
qw_reorder
,
s_reorder
,
zp_reorder
=
ops
.
allspark_repack_weight
(
qw
,
s
,
zp
,
has_zp
)
opcheck
(
torch
.
ops
.
_C
.
rearrange_kn_weight_as_n32k16_order
,
(
qw
,
s
,
zp
,
has_zp
,
qw_reorder
,
s_reorder
,
zp_reorder
,
k
,
n
,
n_32align
))
opcheck
(
torch
.
ops
.
_C
.
allspark_w8a16_gemm
,
(
input
,
qw_reorder
,
s_reorder
,
zp_reorder
,
n
,
group_size
,
sm_count
,
sm_version
,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
has_zp
,
True
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
allspark_w8a16_gemm
(
input
,
qw_reorder
,
s_reorder
,
zp_reorder
,
n
,
group_size
,
sm_count
,
sm_version
,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
has_zp
,
True
)
output_ref
=
torch
.
matmul
(
input
,
w_ref
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
tests/quantization/test_compressed_tensors.py
View file @
6a92ff93
...
...
@@ -215,8 +215,6 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
if
group
is
None
else
group
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
llm
.
apply_model
(
check_model
)
...
...
vllm/_custom_ops.py
View file @
6a92ff93
...
...
@@ -404,6 +404,22 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
memory_format
=
torch
.
contiguous_format
)
if
hasattr
(
torch
.
ops
.
_C
,
"allspark_w8a16_gemm"
):
@
register_fake
(
"_C::allspark_w8a16_gemm"
)
def
_allspark_w8a16_gemm_fake
(
a
:
torch
.
Tensor
,
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_qzeros
:
Optional
[
torch
.
Tensor
],
n
:
torch
.
SymInt
,
group_size
:
torch
.
SymInt
,
sm_count
:
torch
.
SymInt
,
sm_version
:
torch
.
SymInt
,
CUBLAS_M_THRESHOLD
:
torch
.
SymInt
,
has_zp
:
bool
,
n32k16_reorder
:
bool
)
->
torch
.
Tensor
:
m
=
a
.
size
(
0
)
return
torch
.
empty
((
m
,
n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
if
hasattr
(
torch
.
ops
.
_C
,
"ggml_dequantize"
):
@
register_fake
(
"_C::ggml_dequantize"
)
...
...
@@ -881,6 +897,67 @@ def scaled_fp8_quant(
return
output
,
scale
# gptq allspark
def
allspark_repack_weight
(
qweight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
zero_point
:
Optional
[
torch
.
Tensor
]
=
None
,
has_zp
:
bool
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
for Ampere W8A16 Fused Gemm kernel
Args:
qweight: uint8 weight tensor, original k x n format.
scale: fp16/bf16 weight scale tensor, 1 x n format.
zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
Must be provided for asymmetric quantization.
has_zp: if use symmetric quantization, has_zp = False.
if use asymmetric quantization, has_zp = True.
Returns:
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
rearranged weight, scale, and optionally zero_point.
"""
K
=
qweight
.
shape
[
0
]
N
=
qweight
.
shape
[
1
]
N_32align
=
(
N
+
32
-
1
)
//
32
*
32
qweight_reorder
=
torch
.
empty
((
N_32align
,
K
),
device
=
qweight
.
device
,
dtype
=
qweight
.
dtype
)
scale_reorder
=
torch
.
empty
((
1
,
N_32align
),
device
=
scale
.
device
,
dtype
=
scale
.
dtype
)
zero_point_reorder
=
None
if
has_zp
:
assert
zero_point
is
not
None
,
(
"zero_point must be provided for asymmetric quantization."
)
zero_point_reorder
=
torch
.
empty
((
1
,
N_32align
),
device
=
zero_point
.
device
,
dtype
=
zero_point
.
dtype
)
torch
.
ops
.
_C
.
rearrange_kn_weight_as_n32k16_order
(
qweight
,
scale
,
zero_point
,
has_zp
,
qweight_reorder
,
scale_reorder
,
zero_point_reorder
,
K
,
N
,
N_32align
)
return
qweight_reorder
,
scale_reorder
,
zero_point_reorder
def
allspark_w8a16_gemm
(
a
:
torch
.
Tensor
,
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_qzeros
:
Optional
[
torch
.
Tensor
],
n
:
int
,
group_size
:
int
,
sm_count
:
int
,
sm_version
:
int
,
CUBLAS_M_THRESHOLD
:
int
,
has_zp
:
bool
,
n32k16_reorder
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
allspark_w8a16_gemm
(
a
,
b_qweight
,
b_scales
,
b_qzeros
,
n
,
group_size
,
sm_count
,
sm_version
,
CUBLAS_M_THRESHOLD
,
has_zp
,
n32k16_reorder
)
# int8
def
scaled_int8_quant
(
input
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
View file @
6a92ff93
...
...
@@ -3,6 +3,8 @@
from
typing
import
List
,
Optional
,
Type
import
vllm.envs
as
envs
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark
import
(
# noqa: E501
AllSparkLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama
import
(
# noqa: E501
ExllamaLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.machete
import
(
# noqa: E501
...
...
@@ -16,6 +18,7 @@ from vllm.platforms import current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS
:
List
[
Type
[
MPLinearKernel
]]
=
[
MacheteLinearKernel
,
AllSparkLinearKernel
,
MarlinLinearKernel
,
ExllamaLinearKernel
,
]
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py
0 → 100644
View file @
6a92ff93
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
check_allspark_supported_dtype_shape
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
class
AllSparkLinearKernel
(
MPLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
c
.
has_g_idx
:
return
False
,
"Act reordering currently not supported by AllSpark"
if
c
.
zero_points
:
return
False
,
"Zero points currently not supported by AllSpark"
return
check_allspark_supported_dtype_shape
(
c
.
partition_weight_shape
[
0
],
# in_features
c
.
partition_weight_shape
[
1
],
# out_features
c
.
group_size
,
c
.
weight_type
,
c
.
act_type
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
getattr
(
layer
,
self
.
w_q_name
).
device
c
=
self
.
config
# prepare the parameters required for the kernel
properties
=
torch
.
cuda
.
get_device_properties
(
device
.
index
)
sm_count
=
properties
.
multi_processor_count
sm_version
=
properties
.
major
*
10
+
properties
.
minor
gemm_args
=
{}
gemm_args
[
'sm_count'
]
=
sm_count
gemm_args
[
'sm_version'
]
=
sm_version
self
.
gemm_args
=
gemm_args
# transform param weight, scale
old_weight_param
=
getattr
(
layer
,
self
.
w_q_name
)
old_scale_param
=
getattr
(
layer
,
self
.
w_s_name
)
assert
isinstance
(
old_weight_param
,
BasevLLMParameter
)
permute_param_layout_
(
old_weight_param
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
assert
isinstance
(
old_scale_param
,
BasevLLMParameter
)
permute_param_layout_
(
old_scale_param
,
input_dim
=
0
,
output_dim
=
1
)
# unpack weight from K / 4 x N int32 to K x N uint8
new_weight_param
=
torch
.
nn
.
Parameter
(
old_weight_param
.
data
,
requires_grad
=
False
)
new_weight_param
.
data
=
new_weight_param
.
data
.
t
().
contiguous
().
view
(
dtype
=
torch
.
uint8
)
new_weight_param
.
data
=
new_weight_param
.
data
.
t
().
contiguous
()
new_scale_param
=
torch
.
nn
.
Parameter
(
old_scale_param
.
data
,
requires_grad
=
False
)
# reorder K x N weight as N32K16 format for Ampere W8A16
new_weight_param
.
data
,
new_scale_param
.
data
,
_
=
\
ops
.
allspark_repack_weight
(
new_weight_param
.
data
,
new_scale_param
.
data
,
None
,
c
.
zero_points
)
replace_parameter
(
layer
,
self
.
w_q_name
,
new_weight_param
.
data
)
replace_parameter
(
layer
,
self
.
w_s_name
,
new_scale_param
.
data
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
c
=
self
.
config
gemm_args
=
self
.
gemm_args
w_q
,
w_s
,
_
,
_
=
self
.
_get_weight_params
(
layer
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
c
.
partition_weight_shape
[
1
],
)
output
=
ops
.
allspark_w8a16_gemm
(
a
=
reshaped_x
,
b_qweight
=
w_q
,
b_scales
=
w_s
,
b_qzeros
=
None
,
n
=
c
.
partition_weight_shape
[
1
],
group_size
=
c
.
group_size
,
sm_count
=
gemm_args
[
'sm_count'
],
sm_version
=
gemm_args
[
'sm_version'
],
CUBLAS_M_THRESHOLD
=
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
has_zp
=
c
.
zero_points
,
n32k16_reorder
=
True
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/utils/allspark_utils.py
0 → 100644
View file @
6a92ff93
# SPDX-License-Identifier: Apache-2.0
import
torch
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
=
1024
ALLSPARK_SUPPORTED_QUANT_TYPES
=
[
scalar_types
.
uint8b128
]
ALLSPARK_AMPERE_N_ALIGN
=
16
ALLSPARK_AMPERE_K_ALIGN
=
16
def
check_allspark_supported_dtype_shape
(
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
group_size
:
int
,
weight_dtype
:
ScalarType
,
act_dtype
:
torch
.
dtype
):
capability_tuple
=
current_platform
.
get_device_capability
()
device_capability
=
(
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
())
# For Ampere GPU
if
device_capability
>=
80
and
device_capability
<
90
:
if
group_size
!=
-
1
:
return
False
,
\
"For Ampere GPU, AllSpark does not support group_size "
\
f
"=
{
group_size
}
. Only group_size = -1 are supported."
if
weight_dtype
not
in
ALLSPARK_SUPPORTED_QUANT_TYPES
:
return
False
,
"For Ampere GPU, AllSpark does not support "
\
f
"quant type (
{
weight_dtype
}
). Only quant type "
\
f
"(
{
ALLSPARK_SUPPORTED_QUANT_TYPES
}
) are supported."
if
input_size_per_partition
%
ALLSPARK_AMPERE_K_ALIGN
!=
0
\
or
output_size_per_partition
%
ALLSPARK_AMPERE_N_ALIGN
!=
0
:
return
False
,
\
"AllSpark needs input_size_per_partition % "
\
f
"
{
ALLSPARK_AMPERE_K_ALIGN
}
= 0 and "
\
f
"output_size_per_partition %
{
ALLSPARK_AMPERE_N_ALIGN
}
= 0 "
\
"for Ampere GPU optimized kernels."
if
act_dtype
!=
torch
.
float16
and
act_dtype
!=
torch
.
bfloat16
:
return
False
,
\
"AllSpark only supports act_dtype = float16 or bfloat16,"
\
f
"for Ampere GPU, but got act_dtype =
{
act_dtype
}
."
else
:
return
False
,
"AllSpark currently does not support "
\
f
"device_capability =
{
device_capability
}
."
return
True
,
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment