Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6a92ff93
Unverified
Commit
6a92ff93
authored
Mar 01, 2025
by
YajieWang
Committed by
GitHub
Feb 28, 2025
Browse files
[Misc][Kernel]: Add GPTQAllSpark Quantization (#12931)
parent
6a84164a
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
2005 additions
and
4 deletions
+2005
-4
CMakeLists.txt
CMakeLists.txt
+16
-0
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+45
-2
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
+1008
-0
csrc/quantization/gptq_allspark/allspark_repack.cu
csrc/quantization/gptq_allspark/allspark_repack.cu
+163
-0
csrc/quantization/gptq_allspark/allspark_utils.cuh
csrc/quantization/gptq_allspark/allspark_utils.cuh
+408
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+19
-0
tests/kernels/test_allspark_gemm.py
tests/kernels/test_allspark_gemm.py
+100
-0
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+0
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+77
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
...r/layers/quantization/kernels/mixed_precision/__init__.py
+3
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py
...r/layers/quantization/kernels/mixed_precision/allspark.py
+115
-0
vllm/model_executor/layers/quantization/utils/allspark_utils.py
...odel_executor/layers/quantization/utils/allspark_utils.py
+51
-0
No files found.
CMakeLists.txt
100755 → 100644
View file @
6a92ff93
...
...
@@ -317,6 +317,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures"
)
endif
()
# Only build AllSpark kernels if we are building for at least some compatible archs.
# ALLSPARK_ARCHS receives the result of intersecting the listed 8.x
# architectures with the requested CUDA_ARCHS; it is empty when none match.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (ALLSPARK_ARCHS)
  # Sources of the AllSpark W8A16 kernels (weight repack + fused GEMM).
  set(ALLSPARK_SRCS
     "csrc/quantization/gptq_allspark/allspark_repack.cu"
     "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
  # Compile these sources only for the compatible architectures.
  set_gencode_flags_for_srcs(
    SRCS "${ALLSPARK_SRCS}"
    CUDA_ARCHS "${ALLSPARK_ARCHS}")
  # Add them to the extension's source list.
  list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
  message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
  message(STATUS "Not building AllSpark kernels as no compatible archs found"
                 " in CUDA target architectures")
endif()
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection
(
SCALED_MM_3X_ARCHS
"9.0a"
"
${
CUDA_ARCHS
}
"
)
...
...
benchmarks/kernels/benchmark_marlin.py
View file @
6a92ff93
...
...
@@ -10,6 +10,8 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
)
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
ALLSPARK_SUPPORTED_QUANT_TYPES
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
MARLIN_SUPPORTED_GROUP_SIZES
,
query_marlin_supported_quant_types
)
...
...
@@ -18,12 +20,12 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
gptq_quantize_weights
,
sort_weights
)
gptq_pack
,
gptq_quantize_weights
,
quantize_weights
,
sort_weights
)
from
vllm.scalar_type
import
ScalarType
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
...
...
@@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
GPTQ_MARLIN_24_MAX_PARALLEL
)
marlin_zp
=
torch
.
zeros_like
(
marlin_s
,
dtype
=
torch
.
int
)
# AllSpark W8A16 quant
as_supported_case
=
(
quant_type
in
ALLSPARK_SUPPORTED_QUANT_TYPES
and
group_size
==
-
1
and
not
act_order
and
is_k_full
)
if
as_supported_case
:
properties
=
torch
.
cuda
.
get_device_properties
(
b
.
device
.
index
)
sm_count
=
properties
.
multi_processor_count
sm_version
=
properties
.
major
*
10
+
properties
.
minor
supported_arch
=
(
sm_version
>=
80
and
sm_version
<
90
)
as_supported_case
=
as_supported_case
and
supported_arch
if
supported_arch
:
has_zp
=
False
w_ref
,
qw
,
s
,
zp
=
quantize_weights
(
b
,
quant_type
,
group_size
,
has_zp
)
qw
=
qw
.
to
(
torch
.
uint8
)
qw_reorder
,
s_reorder
,
zp_reorder
=
\
ops
.
allspark_repack_weight
(
qw
,
s
,
zp
,
has_zp
)
CUBLAS_M_THRESHOLD
=
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
globals
=
{
# Gen params
"quant_type"
:
quant_type
,
...
...
@@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
# GPTQ params
"q_w_gptq"
:
q_w_gptq
,
"repack_sort_indices"
:
repack_sort_indices
,
# AllSpark W8A16 params
"qw_reorder"
:
qw_reorder
if
as_supported_case
else
None
,
"s_reorder"
:
s_reorder
if
as_supported_case
else
None
,
"zp_reorder"
:
zp_reorder
if
as_supported_case
else
None
,
"sm_count"
:
sm_count
if
as_supported_case
else
None
,
"sm_version"
:
sm_version
if
as_supported_case
else
None
,
"CUBLAS_M_THRESHOLD"
:
CUBLAS_M_THRESHOLD
if
as_supported_case
else
None
,
# Kernels
"gptq_marlin_gemm"
:
ops
.
gptq_marlin_gemm
,
"gptq_marlin_24_gemm"
:
ops
.
gptq_marlin_24_gemm
,
"gptq_marlin_repack"
:
ops
.
gptq_marlin_repack
,
"allspark_w8a16_gemm"
:
ops
.
allspark_w8a16_gemm
,
}
min_run_time
=
1
...
...
@@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
description
=
"gptq_marlin_repack"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
if
as_supported_case
:
results
.
append
(
benchmark
.
Timer
(
stmt
=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
"allspark_w8a16_gemm_fp32"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
def
main
(
args
):
print
(
"Benchmarking models:"
)
...
...
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
0 → 100644
View file @
6a92ff93
This diff is collapsed.
Click to expand it.
csrc/quantization/gptq_allspark/allspark_repack.cu
0 → 100644
View file @
6a92ff93
#include "allspark_utils.cuh"
#include <torch/all.h>
#include "core/registration.h"
namespace
allspark
{
// Rearrange B to facilitate Ampere Tensor Core load data
// reorder B from (K, N) to (N_32align / 4, K * 4)
// K % 16 == 0, N % 16 == 0, N_32align % 32 == 0
//
// Work split (128 threads = 4 warps per block):
//   * blockIdx.x != gridDim.x - 1: the block repacks a 64(k) x 128(n) tile of
//     B; each thread loads four 16-byte row segments, transposes them in
//     registers and writes eight 8-byte segments to B_result.
//   * blockIdx.x == gridDim.x - 1: the block permutes 128 consecutive entries
//     of B_scale (and of B_zero when it is non-null) into the *_result arrays.
//
// Out-of-range source elements are read as zero (ldg128_cg_0 / ldg16_cg_0
// zero-fill when their guard is false); out-of-range destinations are skipped.
//
// Element offsets are computed in 64 bits so that K * N products that do not
// fit in a 32-bit int cannot overflow.
template <typename FType>
__global__ void __launch_bounds__(128)
    rearrange_kn_weight_as_n32k16_order_ldg16_kernel(
        const uint8_t* B, const FType* B_scale, const FType* B_zero,
        uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
        const int K, const int N, const int N_32align) {
  const int lane_id = threadIdx.x % 32;
  const int warp_id = threadIdx.x / 32;

  if (blockIdx.x != gridDim.x - 1) {
    // Load B
    // per block process 64(k) * 128(n) B elements
    // per warp process 16(k) * 128 B elements
    const int src_row_base_idx =
        blockIdx.x * 64 + warp_id * 16 + ((lane_id % 8) / 2) * 2;
    const int src_col_idx =
        blockIdx.y * 128 + (lane_id / 8) * 32 + (lane_id % 2) * 16;

    // Rows are accessed through uint32_t references below, so keep each
    // 16-byte row suitably aligned even if the array is spilled to memory.
    alignas(16) uint8_t B_frag[4][16];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      const int src_row_idx = src_row_base_idx + (i / 2) * 8 + (i % 2);
      const int64_t src_offset =
          static_cast<int64_t>(src_row_idx) * N + src_col_idx;
      const bool guard = src_row_idx < K && src_col_idx < N;
      uint32_t* frag_words = reinterpret_cast<uint32_t*>(B_frag[i]);
      // The pointer is only dereferenced when guard is true.
      ldg128_cg_0(frag_words[0], frag_words[1], frag_words[2], frag_words[3],
                  B + src_offset, guard);
    }

    // reorder B
    // B_frag[i][j] (row i, column j of this thread's 4 x 16 fragment) goes to
    // B_reorder_frag[j % 8][i + (j / 8) * 4].
    alignas(8) uint8_t B_reorder_frag[8][8];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
#pragma unroll
      for (int j = 0; j < 16; ++j) {
        const int dst_i = j % 8;
        const int dst_j = i + (j / 8) * 4;
        B_reorder_frag[dst_i][dst_j] = B_frag[i][j];
      }
    }

    // Store B
    const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
    const int dst_col_idx =
        blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
    // The guard does not depend on i: dst_row_base_idx is a multiple of 8 and
    // N_32align / 4 is a multiple of 8 (N_32align % 32 == 0), so all eight
    // rows of this thread are either in range or out of range together.
    const bool dst_guard =
        (dst_row_base_idx < N_32align / 4) && (dst_col_idx < K * 4);
    if (dst_guard) {
      for (int i = 0; i < 8; ++i) {
        const int dst_row_idx = dst_row_base_idx + i;
        const int64_t dst_offset =
            static_cast<int64_t>(dst_row_idx) * K * 4 + dst_col_idx;
        *reinterpret_cast<int2*>(B_result + dst_offset) =
            *reinterpret_cast<int2*>(B_reorder_frag[i]);
      }
    }
  } else {
    // Load B_scale and B_zero
    FType b_scale_reg, b_zero_reg;
    const int src_offset = blockIdx.y * 128 + threadIdx.x;
    ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
    if (B_zero != nullptr)
      ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);

    // Permute within each group of 32 entries: the value read by lane l of a
    // warp is written to slot (l % 8) * 4 + l / 8 of that warp's group.
    const int dst_offset =
        blockIdx.y * 128 + warp_id * 32 + (lane_id % 8) * 4 + lane_id / 8;
    if (dst_offset < N_32align) {
      B_scale_result[dst_offset] = b_scale_reg;
      // b_zero_reg is only read when it was loaded above.
      if (B_zero != nullptr) B_zero_result[dst_offset] = b_zero_reg;
    }
  }
}
template
<
typename
FType
>
void
rearrange_kn_weight_as_n32k16_order_ldg16
(
const
uint8_t
*
B
,
const
FType
*
B_scale
,
const
FType
*
B_zero
,
uint8_t
*
B_result
,
FType
*
B_scale_result
,
FType
*
B_zero_result
,
const
int64_t
K
,
const
int64_t
N
,
const
int64_t
N_32align
,
cudaStream_t
stream
)
{
if
(
N
%
16
!=
0
||
K
%
16
!=
0
)
{
std
::
cerr
<<
"Now only support N and K is multiples of 16"
<<
std
::
endl
;
}
const
int
BLOCK
=
128
;
int
grid_x
=
(
K
+
64
-
1
)
/
64
+
1
;
int
grid_y
=
(
N
+
128
-
1
)
/
128
;
dim3
grid
(
grid_x
,
grid_y
);
rearrange_kn_weight_as_n32k16_order_ldg16_kernel
<
FType
>
<<<
grid
,
BLOCK
,
0
,
stream
>>>
(
B
,
B_scale
,
B_zero
,
B_result
,
B_scale_result
,
B_zero_result
,
K
,
N
,
N_32align
);
}
}
// namespace allspark
// Torch entry point: repack a (K, N) uint8 weight and its scales (and
// zero-points when has_zp is true) into the n32k16 layout, writing into the
// caller-provided *_reorder tensors.
//
// All tensors that are used must be contiguous CUDA tensors. b_scales must be
// Half or BFloat16; its dtype selects the kernel instantiation. The work is
// enqueued on the current CUDA stream.
//
// Throws (TORCH_CHECK) on a non-CUDA / non-contiguous tensor, on has_zp with a
// missing b_zeros / b_zeros_reorder, and on an unsupported scale dtype.
// Previously an unsupported dtype silently returned without writing outputs.
void rearrange_kn_weight_as_n32k16_order(
    torch::Tensor const& b_qweight, torch::Tensor const& b_scales,
    c10::optional<torch::Tensor> const& b_zeros, bool has_zp,
    torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder,
    c10::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,
    const int64_t N, const int64_t N_32align) {
  // Verify device and strides
  TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU");
  TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous");

  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  TORCH_CHECK(b_qweight_reorder.device().is_cuda(),
              "b_qweight_reorder is not on GPU");
  TORCH_CHECK(b_qweight_reorder.is_contiguous(),
              "b_qweight_reorder is not contiguous");

  TORCH_CHECK(b_scales_reorder.device().is_cuda(),
              "b_scales_reorder is not on GPU");
  TORCH_CHECK(b_scales_reorder.is_contiguous(),
              "b_scales_reorder is not contiguous");

  if (has_zp) {
    // Give a clear error instead of std::bad_optional_access from value().
    TORCH_CHECK(b_zeros.has_value(),
                "b_zeros must be provided when has_zp is true");
    TORCH_CHECK(b_zeros_reorder.has_value(),
                "b_zeros_reorder must be provided when has_zp is true");
    TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU");
    TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous");

    TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(),
                "b_zeros_reorder is not on GPU");
    TORCH_CHECK(b_zeros_reorder.value().is_contiguous(),
                "b_zeros_reorder is not contiguous");
  }

  const bool is_half = b_scales.dtype() == at::ScalarType::Half;
  const bool is_bf16 = b_scales.dtype() == at::ScalarType::BFloat16;
  TORCH_CHECK(is_half || is_bf16,
              "b_scales must be Half or BFloat16, got ", b_scales.dtype());

  const uint8_t* matB = reinterpret_cast<const uint8_t*>(b_qweight.data_ptr());
  const void* b_scale = b_scales.data_ptr();
  // Zero-point pointers stay null for symmetric quantization; the kernel
  // skips zero-point handling when B_zero is null.
  const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr;

  uint8_t* matB_reorder =
      reinterpret_cast<uint8_t*>(b_qweight_reorder.data_ptr());
  void* b_scale_reorder = b_scales_reorder.data_ptr();
  void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr;

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (is_half) {
    allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>(
        matB, reinterpret_cast<const __half*>(b_scale),
        reinterpret_cast<const __half*>(b_zero), matB_reorder,
        reinterpret_cast<__half*>(b_scale_reorder),
        reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream);
  } else {
    allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>(
        matB, reinterpret_cast<const __nv_bfloat16*>(b_scale),
        reinterpret_cast<const __nv_bfloat16*>(b_zero), matB_reorder,
        reinterpret_cast<__nv_bfloat16*>(b_scale_reorder),
        reinterpret_cast<__nv_bfloat16*>(b_zero_reorder), K, N, N_32align,
        stream);
  }
}
// Register the CUDA implementation of the
// "rearrange_kn_weight_as_n32k16_order" op for this extension.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("rearrange_kn_weight_as_n32k16_order",
         &rearrange_kn_weight_as_n32k16_order);
}
csrc/quantization/gptq_allspark/allspark_utils.cuh
0 → 100644
View file @
6a92ff93
#pragma once
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <iostream>
namespace
allspark
{
// Evaluate a CUDA runtime call; if it does not return cudaSuccess, print the
// file, line and error string to std::cerr and terminate with exit(-1).
#define CHECK_CUDA(cmd) \
do { \
cudaError_t cuda_status = cmd; \
if (cuda_status != cudaSuccess) { \
std::string err_str = cudaGetErrorString(cuda_status); \
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
<< err_str; \
exit(-1); \
} \
} while (0)
// Same for a cuBLAS call: on any status other than CUBLAS_STATUS_SUCCESS,
// print the file, line and numeric status to std::cerr and exit(-1).
// NOTE(review): this header does not include a cuBLAS header itself; users of
// this macro presumably include one before expanding it - confirm.
#define CHECK_CUBLAS(cmd) \
do { \
cublasStatus_t cublas_status = cmd; \
if (cublas_status != CUBLAS_STATUS_SUCCESS) { \
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
<< cublas_status << std::endl; \
exit(-1); \
} \
} while (0)
// Parameter bundle for the SM8x W8A16 split-K GEMM (the kernels that consume
// it are not in this header). FType is the 16-bit floating-point type and
// QType the quantized weight type.
// NOTE(review): field meanings below other than the three commented ones are
// inferred from their names - confirm against the GEMM kernel.
template <typename FType, typename QType>
struct SM8x_GEMM_W8A16_Splitk_Params {
  const FType* A_ptr;        // presumably the activation matrix
  const QType* B_ptr;        // presumably the quantized weight
  const FType* B_scale_ptr;  // presumably the weight scales
  const FType* B_zero_ptr;   // presumably the weight zero-points
  FType* C_ptr;              // presumably the output matrix
  int M;
  int N;
  int K;
  int SplitK;
  int GroupCnt;
  int GroupSize;
  FType* C_split_ptr;       // for non-fused splitk reduce
  float* C_tmp_ptr;         // for fused splitk reduce
  uint32_t* red_count_ptr;  // for fused splitk reduce
};
struct
alignas
(
16
)
BlockTileSplitkParams
{
int
Mtile
;
int
Ntile
;
int
SplitK
;
bool
EnableFuse
;
};
// Sum n_matrix partial result matrices element-wise.
//
// C_split holds the partial matrices back to back, each of matrix_size
// elements; thread t of the grid writes C[t] = sum_i C_split[t + i *
// matrix_size]. Accumulation is done in FType. When N_MATRIX > 0 it is used
// as the (compile-time) number of partials, otherwise the runtime n_matrix
// is used. The parameter n is not used by the kernel.
template <typename FType, int BLOCK, int N_MATRIX>
__global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
                                              uint32_t n, uint32_t n_matrix,
                                              uint32_t matrix_size) {
  const uint32_t elem = blockIdx.x * BLOCK + threadIdx.x;
  if (elem < matrix_size) {
    const int num_splits =
        (N_MATRIX > 0) ? N_MATRIX : static_cast<int>(n_matrix);

    FType acc(0);
    for (int split = 0; split < num_splits; ++split) {
      acc += C_split[elem + split * matrix_size];
    }

    C[elem] = acc;
  }
}
// Launch f16_gemm_splitk_reduce_kernel on `stream` to reduce n_matrix partial
// (m x n) matrices stored back to back in C_split into C.
//
// For n_matrix in [4, 12] a kernel instantiation with the split count as a
// template constant is selected; any other value uses the generic
// instantiation (N_MATRIX = -1) that reads n_matrix at run time.
//
// Returns without launching when m * n == 0: a grid with zero blocks is an
// invalid launch configuration.
template <typename FType>
void f16_gemm_splitk_reduce(const FType* C_split, FType* C, const uint32_t m,
                            const uint32_t n, const uint32_t n_matrix,
                            cudaStream_t stream) {
  constexpr int BLOCK = 128;
  const uint32_t matrix_size = m * n;
  if (matrix_size == 0) {
    return;
  }
  const int grid = (matrix_size + BLOCK - 1) / BLOCK;

  void (*kernel)(const FType*, FType*, uint32_t, uint32_t, uint32_t) = nullptr;

  switch (n_matrix) {
    case 4:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 4>;
      break;
    case 5:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 5>;
      break;
    case 6:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 6>;
      break;
    case 7:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 7>;
      break;
    case 8:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 8>;
      break;
    case 9:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 9>;
      break;
    case 10:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 10>;
      break;
    case 11:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 11>;
      break;
    case 12:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 12>;
      break;
    default:
      kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, -1>;
      break;
  }

  kernel<<<grid, BLOCK, 0, stream>>>(C_split, C, n, n_matrix, matrix_size);
}
// Type trait mapping a 16-bit floating-point type to its scalar (T1) and
// packed two-element (T2) CUDA types. Only specialized for half and
// __nv_bfloat16; other types are left undefined.
template <typename T>
struct HalfType;

template <>
struct HalfType<half> {
  using T1 = __half;
  using T2 = __half2;
};

template <>
struct HalfType<__nv_bfloat16> {
  using T1 = __nv_bfloat16;
  using T2 = __nv_bfloat162;
};
// convert 64-bit pointer to 32-bit smem addr
// Uses cvta.to.shared to turn the generic pointer into a shared-space address
// and truncates it to 32 bits, the form the shared-memory PTX wrappers in
// this file take as their address operand.
__device__ __forceinline__ uint32_t smem_u32addr(const void* smem_ptr) {
  uint32_t addr;
  asm("{.reg .u64 u64addr;\n"
      " cvta.to.shared.u64 u64addr, %1;\n"
      " cvt.u32.u64 %0, u64addr;}\n"
      : "=r"(addr)
      : "l"(smem_ptr));

  return addr;
}
// Guarded 16-bit global load with zero fill.
//
// When guard is true, loads 2 bytes from ptr into r0; when false, r0 is set
// to 0 and ptr is not dereferenced. T must be a 2-byte type.
//
// The L2::128B prefetch variant is used on CUDA >= 11.4 when compiling for
// sm_75 or newer. The version test compares (major, minor) as a pair: the
// previous `major >= 11 && minor >= 4` form wrongly rejected CUDA 12.0-12.3.
// NOTE(review): the fallback path uses the .ca cache operator while the
// prefetch path uses .cg; kept as in the original.
template <typename T>
__device__ __forceinline__ void ldg16_cg_0(T& r0, const void* ptr,
                                           bool guard) {
  static_assert(sizeof(T) == 2, "ldg16_cg_0: invalid T");

  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %2, 0;\n"
      " @!p mov.b16 %0, 0;\n"
#if (__CUDACC_VER_MAJOR__ > 11 ||                                  \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)) && \
    __CUDA_ARCH__ >= 750
      " @p ld.global.cg.L2::128B.b16 {%0}, [%1];}\n"
#else
      " @p ld.global.ca.b16 {%0}, [%1];}\n"
#endif
      : "=h"(reinterpret_cast<uint16_t&>(r0))
      : "l"(ptr), "r"((int)guard));
}
// Guarded 64-bit global load (two 32-bit words) with the .ca cache operator.
//
// When guard is true, loads 8 bytes from ptr into r0 and r1. T must be a
// 4-byte type.
// NOTE(review): unlike the *_0 variants there is no zero fill, and the
// outputs are declared write-only ("=r"), so r0 / r1 should be treated as
// unspecified when guard is false - confirm callers do not rely on them.
//
// The L2::128B prefetch variant is used on CUDA >= 11.4 when compiling for
// sm_75 or newer. The version test compares (major, minor) as a pair: the
// previous `major >= 11 && minor >= 4` form wrongly rejected CUDA 12.0-12.3.
template <typename T>
__device__ __forceinline__ void ldg64_ca(T& r0, T& r1, const void* ptr,
                                         bool guard) {
  static_assert(sizeof(T) == 4, "ldg64_ca: invalid T");

  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %3, 0;\n"
#if (__CUDACC_VER_MAJOR__ > 11 ||                                  \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)) && \
    __CUDA_ARCH__ >= 750
      " @p ld.global.ca.L2::128B.v2.b32 {%0, %1}, [%2];}\n"
#else
      " @p ld.global.ca.v2.b32 {%0, %1}, [%2];}\n"
#endif
      : "=r"(reinterpret_cast<uint32_t&>(r0)),
        "=r"(reinterpret_cast<uint32_t&>(r1))
      : "l"(ptr), "r"((int)guard));
}
// Guarded 128-bit global load (four 32-bit words) with zero fill, using the
// .cg cache operator.
//
// When guard is true, loads 16 bytes from ptr into r0..r3; when false, all
// four outputs are set to 0 and ptr is not dereferenced. T must be a 4-byte
// type.
//
// The L2::128B prefetch variant is used on CUDA >= 11.4 when compiling for
// sm_75 or newer. The version test compares (major, minor) as a pair: the
// previous `major >= 11 && minor >= 4` form wrongly rejected CUDA 12.0-12.3.
template <typename T>
__device__ __forceinline__ void ldg128_cg_0(T& r0, T& r1, T& r2, T& r3,
                                            const void* ptr, bool guard) {
  static_assert(sizeof(T) == 4, "ldg128_cg_0: invalid T");

  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %5, 0;\n"
      " @!p mov.b32 %0, 0;\n"
      " @!p mov.b32 %1, 0;\n"
      " @!p mov.b32 %2, 0;\n"
      " @!p mov.b32 %3, 0;\n"
#if (__CUDACC_VER_MAJOR__ > 11 ||                                  \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)) && \
    __CUDA_ARCH__ >= 750
      " @p ld.global.cg.L2::128B.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
#else
      " @p ld.global.cg.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
#endif
      : "=r"(reinterpret_cast<uint32_t&>(r0)),
        "=r"(reinterpret_cast<uint32_t&>(r1)),
        "=r"(reinterpret_cast<uint32_t&>(r2)),
        "=r"(reinterpret_cast<uint32_t&>(r3))
      : "l"(ptr), "r"((int)guard));
}
// Unconditional 128-bit shared-memory load: reads four 32-bit words from the
// 32-bit shared address `addr` (see smem_u32addr) into reg0..reg3.
// T must be a 4-byte type.
template <typename T>
__device__ __forceinline__ void lds128(T& reg0, T& reg1, T& reg2, T& reg3,
                                       const uint32_t addr) {
  static_assert(sizeof(T) == 4, "lds128: invalid T");

  asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
               : "=r"(reinterpret_cast<uint32_t&>(reg0)),
                 "=r"(reinterpret_cast<uint32_t&>(reg1)),
                 "=r"(reinterpret_cast<uint32_t&>(reg2)),
                 "=r"(reinterpret_cast<uint32_t&>(reg3))
               : "r"(addr));
}
// Guarded 128-bit global store: when guard is true, writes the four 32-bit
// words r0..r3 to ptr; when false, nothing is stored.
// T must be a 4-byte type. ptr is declared const void* although the memory it
// points to is written.
template <typename T>
__device__ __forceinline__ void stg128(const T& r0, const T& r1, const T& r2,
                                       const T& r3, const void* ptr,
                                       bool guard) {
  static_assert(sizeof(T) == 4, "stg128: invalid T");

  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %1, 0;\n"
      " @p st.global.v4.b32 [%0], {%2, %3, %4, %5};}\n"
      :
      : "l"(ptr), "r"((int)guard), "r"(reinterpret_cast<const uint32_t&>(r0)),
        "r"(reinterpret_cast<const uint32_t&>(r1)),
        "r"(reinterpret_cast<const uint32_t&>(r2)),
        "r"(reinterpret_cast<const uint32_t&>(r3)));
}
// Wrapper around ldmatrix.sync.aligned.m8n8.x4.shared.b16: loads four 8x8
// b16 matrices from shared memory at the 32-bit shared address `addr` into
// r0..r3. T must be a 4-byte type.
// Compiles to an empty body (outputs untouched) unless built for sm_75 or
// newer with CUDA >= 11.
template <typename T>
__device__ __forceinline__ void ldsm_4(T& r0, T& r1, T& r2, T& r3,
                                       const uint32_t& addr) {
  static_assert(sizeof(T) == 4, "ldsm_4: invalid T");
#if (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 11)
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(reinterpret_cast<uint32_t&>(r0)),
        "=r"(reinterpret_cast<uint32_t&>(r1)),
        "=r"(reinterpret_cast<uint32_t&>(r2)),
        "=r"(reinterpret_cast<uint32_t&>(r3))
      : "r"(addr));
#endif
}
// Warp-level m16n8k16 tensor-core multiply-accumulate with fp32 accumulators:
// d = a * b + d, where a (4 registers) and b (2 registers) hold packed 16-bit
// operands of type FType and d holds 4 fp32 values that are both read and
// written. Specialized below for __half and __nv_bfloat16.
// Both specializations compile to an empty body (d untouched) unless built
// for sm_80 or newer with CUDA >= 11.
template <typename FType>
__device__ __forceinline__ void hmma16816_f32(float (&d)[4],
                                              const uint32_t (&a)[4],
                                              const uint32_t (&b)[2]);

// fp16 operands: mma.sync ... f32.f16.f16.f32.
template <>
__device__ __forceinline__ void hmma16816_f32<__half>(float (&d)[4],
                                                      const uint32_t (&a)[4],
                                                      const uint32_t (&b)[2]) {
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, "
      "{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
      : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#endif
}

// bf16 operands: mma.sync ... f32.bf16.bf16.f32.
template <>
__device__ __forceinline__ void hmma16816_f32<__nv_bfloat16>(
    float (&d)[4], const uint32_t (&a)[4], const uint32_t (&b)[2]) {
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, "
      "{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
      : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#endif
}
// Guarded asynchronous global -> shared copy using cp.async with the .cg
// cache operator.
//
// When guard is true, issues a copy of SIZE_IN_BYTES bytes to the 32-bit
// shared address smem_addr, passing src_in_bytes as the source-size operand
// (the number of bytes actually read from gmem_ptr). Nothing is issued when
// guard is false. Compiles to an empty body unless built for sm_80 or newer
// with CUDA >= 11.
//
// The L2::256B prefetch variant is selected on CUDA >= 11.4. The inner test
// now also accepts any major version above 11: the previous bare
// `minor >= 4` check wrongly rejected CUDA 12.0-12.3.
// NOTE(review): the static_assert admits 4 and 8 bytes, but the PTX ISA
// restricts cp.async.cg to 16-byte copies - confirm callers only use 16.
template <int SIZE_IN_BYTES>
__device__ __forceinline__ void cp_async(const uint32_t smem_addr,
                                         const void* gmem_ptr,
                                         const int src_in_bytes, bool guard) {
  static_assert(
      (SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
      "Size is not supported");
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile(
      "{.reg.pred p;\n"
      " setp.ne.b32 p, %4, 0;\n"
  #if __CUDACC_VER_MAJOR__ > 11 || __CUDACC_VER_MINOR__ >= 4
      " @p cp.async.cg.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
  #else
      " @p cp.async.cg.shared.global [%0], [%1], %2, %3;}\n"
  #endif
      ::"r"(smem_addr),
      "l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
#endif
}
// Guarded asynchronous global -> shared copy using cp.async with the .ca
// cache operator.
//
// When guard is true, issues a copy of SIZE_IN_BYTES bytes (4, 8 or 16) to
// the 32-bit shared address smem_addr, passing src_in_bytes as the
// source-size operand (the number of bytes actually read from gmem_ptr).
// Nothing is issued when guard is false. Compiles to an empty body unless
// built for sm_80 or newer with CUDA >= 11.
//
// The L2::256B prefetch variant is selected on CUDA >= 11.4. The inner test
// now also accepts any major version above 11: the previous bare
// `minor >= 4` check wrongly rejected CUDA 12.0-12.3.
template <int SIZE_IN_BYTES>
__device__ __forceinline__ void cp_async_ca(const uint32_t smem_addr,
                                            const void* gmem_ptr,
                                            const int src_in_bytes,
                                            bool guard) {
  static_assert(
      (SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
      "Size is not supported");
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile(
      "{.reg.pred p;\n"
      " setp.ne.b32 p, %4, 0;\n"
  #if __CUDACC_VER_MAJOR__ > 11 || __CUDACC_VER_MINOR__ >= 4
      " @p cp.async.ca.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
  #else
      " @p cp.async.ca.shared.global [%0], [%1], %2, %3;}\n"
  #endif
      ::"r"(smem_addr),
      "l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
#endif
}
// Issue cp.async.commit_group, closing the current group of outstanding
// cp.async operations issued by this thread.
// Compiles to an empty body unless built for sm_80 or newer with CUDA >= 11.
__device__ __forceinline__ void cp_async_commit_group() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.commit_group;\n");
#endif
}
// Issue cp.async.wait_group N: wait until at most N of the most recently
// committed cp.async groups are still pending.
// Compiles to an empty body unless built for sm_80 or newer with CUDA >= 11.
// NOTE: the name is spelled "cp_asyc" (sic); it is kept because callers
// outside this header refer to it by this name.
template <int N>
__device__ __forceinline__ void cp_asyc_wait_group() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.wait_group %0;\n" : : "n"(N));
#endif
}
// Convert the four uint8 values packed in idata to four 16-bit floats,
// subtracting a bias of 128 from each, and write them as two packed pairs to
// fdata[0] and fdata[1]. T is the packed pair type; specialized below for
// __half2 and __nv_bfloat162.
template <typename T>
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128(const uint32_t& idata,
                                                          T* fdata);

template <>
// fast conversion: 4xuint8 to 4xhalf, subtracting bias = 128
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__half2>(
    const uint32_t& idata, __half2* fdata) {
  uint32_t i10, i32;
  // prmt places each source byte x under the constant byte 0x64, giving the
  // fp16 bit pattern 0x64xx, i.e. the value 1024 + x. i10 holds bytes 0 and 1
  // of idata, i32 holds bytes 2 and 3.
  asm volatile(
      "prmt.b32 %0, %2, 0x64, 0x4140;"
      "prmt.b32 %1, %2, 0x64, 0x4342;"
      : "=r"(i10), "=r"(i32)
      : "r"(idata));

  // 0x6480 is fp16 for 1024 + 128, so the subtraction yields x - 128.
  static constexpr uint32_t MAGIC_NUM = 0x64806480;
  fdata[0] = __hsub2(reinterpret_cast<const __half2&>(i10),
                     reinterpret_cast<const __half2&>(MAGIC_NUM));
  fdata[1] = __hsub2(reinterpret_cast<const __half2&>(i32),
                     reinterpret_cast<const __half2&>(MAGIC_NUM));
}
template <>
// fast conversion: 4xuint8 to 4xbfloat16, subtracting bias = 128
// reference from marlin fast implementation
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>(
    const uint32_t& idata, __nv_bfloat162* fdata) {
  float fp32_imd[4];
  uint32_t* fp32_imd_casted = reinterpret_cast<uint32_t*>(fp32_imd);
  // Each prmt builds the fp32 bit pattern 0x4B0000xx from one source byte x,
  // i.e. the value 8388608 + x (2^23 + x).
  asm volatile(
      "prmt.b32 %0, %4, 0x4B000000, 0x7650;"
      "prmt.b32 %1, %4, 0x4B000000, 0x7651;"
      "prmt.b32 %2, %4, 0x4B000000, 0x7652;"
      "prmt.b32 %3, %4, 0x4B000000, 0x7653;"
      : "=r"(fp32_imd_casted[0]), "=r"(fp32_imd_casted[1]),
        "=r"(fp32_imd_casted[2]), "=r"(fp32_imd_casted[3])
      : "r"(idata));

  // 8388736 = 2^23 + 128, so each element becomes x - 128 as fp32.
  fp32_imd[0] -= 8388736.f;
  fp32_imd[1] -= 8388736.f;
  fp32_imd[2] -= 8388736.f;
  fp32_imd[3] -= 8388736.f;

  uint32_t* bf16_res = reinterpret_cast<uint32_t*>(fdata);
  // Keep the upper 16 bits of each fp32 value (its bf16 truncation) and pack
  // two of them per 32-bit word: elements 0,1 -> fdata[0], 2,3 -> fdata[1].
  asm volatile(
      "prmt.b32 %0, %2, %3, 0x7632;"
      "prmt.b32 %1, %4, %5, 0x7632;"
      : "=r"(bf16_res[0]), "=r"(bf16_res[1])
      : "r"(fp32_imd_casted[0]), "r"(fp32_imd_casted[1]),
        "r"(fp32_imd_casted[2]), "r"(fp32_imd_casted[3]));
}
// Broadcast a scalar bfloat16 into both lanes of a packed pair.
// When compiled for a device architecture below sm_80 the body is just
// assert(false); no value is produced on that path.
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat162bfloat162(x);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Broadcast a scalar half into both lanes of a packed pair.
static __device__ half2 inline num2num2(const half x) {
  return __half2half2(x);
}
}
// namespace allspark
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
6a92ff93
...
...
@@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor!? azp) -> ()"
);
ops
.
impl
(
"dynamic_scaled_int8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_int8_quant
);
#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops
.
def
(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
"Tensor? b_zeros, "
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
"Tensor!? b_zeros_reorder, "
"int K, int N, int N_32align) -> ()"
);
// conditionally compiled so impl in source file
// AllSpark quantization ops
ops
.
def
(
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
"Tensor? b_qzeros, "
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"
);
// conditionally compiled so impl in source file
#endif
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cache_ops
),
cache_ops
)
{
...
...
tests/kernels/test_allspark_gemm.py
0 → 100644
View file @
6a92ff93
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
ALLSPARK_AMPERE_K_ALIGN
,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
ALLSPARK_AMPERE_N_ALIGN
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
quantize_weights
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
def is_gptq_allspark_supported(min_capability: int,
                               max_capability: int) -> bool:
    """Return True if the current device can run the AllSpark kernels.

    The platform must be CUDA and its compute capability (as the integer
    returned by ``to_int()``) must lie in the inclusive range
    ``[min_capability, max_capability]``.
    """
    if current_platform.is_cuda():
        capability = current_platform.get_device_capability()
        assert capability is not None

        cap = capability.to_int()
        return min_capability <= cap <= max_capability

    return False
# (m, n_factor, k_factor) triples. In the test below m is used as-is, while
# n and k are the factors multiplied by ALLSPARK_AMPERE_N_ALIGN and
# ALLSPARK_AMPERE_K_ALIGN respectively.
MNK_FACTORS = [
    (1, 4, 8),
    (13, 17, 67),
    (26, 37, 13),
    (48, 16, 24),
    (67, 13, 88),
    (257, 13, 11),
    (658, 13, 11),
    (1033, 9, 17),
]

# Activation / weight dtypes exercised by the test.
DTYPES = [torch.float16, torch.bfloat16]
# Whether quantization uses a zero-point.
HAS_ZP_OPTS = [False, True]
def compute_max_diff(output, output_ref):
    """Return mean(|output - output_ref|) / mean(|output_ref|).

    Despite the name, this is a mean relative error, not a maximum.
    """
    abs_err = (output - output_ref).abs()
    ref_magnitude = output_ref.abs()
    return torch.mean(abs_err) / torch.mean(ref_magnitude)
def rand_data(shape, dtype=torch.float16):
    """Return a ``torch.randn`` tensor of ``shape`` and ``dtype`` on "cuda"."""
    return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(
    not is_gptq_allspark_supported(80, 89),
    reason="AllSpark Ampere kernel is not supported on this GPU type.")
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
    """Check the AllSpark W8A16 GEMM against a dense matmul reference.

    Quantizes a random weight, repacks it with ``allspark_repack_weight``,
    runs ``opcheck`` on both custom ops, and asserts that the mean relative
    error of the kernel output versus ``input @ w_ref`` is below 0.04.
    """
    m, n_factor, k_factor = mnk_factors
    n = n_factor * ALLSPARK_AMPERE_N_ALIGN
    k = k_factor * ALLSPARK_AMPERE_K_ALIGN

    activations = rand_data((m, k), dtype=dtype)
    dense_weight = rand_data((k, n), dtype=dtype)

    # Quantize the weight; w_ref is the reference weight used for the matmul.
    w_ref, qw, s, zp = quantize_weights(dense_weight, scalar_types.uint8b128,
                                        group_size, has_zp)
    qw = qw.to(torch.uint8)
    if has_zp:
        zp = zp.to(dtype)

    device_props = torch.cuda.get_device_properties(qw.device.index)
    sm_count = device_props.multi_processor_count
    sm_version = device_props.major * 10 + device_props.minor

    # N rounded up to a multiple of 32.
    n_32align = (n + 32 - 1) // 32 * 32

    qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
        qw, s, zp, has_zp)
    repack_args = (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
                   n_32align)
    opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order, repack_args)

    # The same argument list feeds both the opcheck and the real call.
    gemm_args = (activations, qw_reorder, s_reorder, zp_reorder, n,
                 group_size, sm_count, sm_version,
                 ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True)
    opcheck(torch.ops._C.allspark_w8a16_gemm,
            gemm_args,
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)
    output = ops.allspark_w8a16_gemm(*gemm_args)

    output_ref = torch.matmul(activations, w_ref)
    torch.cuda.synchronize()

    rel_err = compute_max_diff(output, output_ref)
    assert rel_err < 0.04
tests/quantization/test_compressed_tensors.py
View file @
6a92ff93
...
...
@@ -215,8 +215,6 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
if
group
is
None
else
group
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
llm
.
apply_model
(
check_model
)
...
...
vllm/_custom_ops.py
View file @
6a92ff93
...
...
@@ -404,6 +404,22 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
memory_format
=
torch
.
contiguous_format
)
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):

    @register_fake("_C::allspark_w8a16_gemm")
    def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor,
                                  b_scales: torch.Tensor,
                                  b_qzeros: Optional[torch.Tensor],
                                  n: torch.SymInt, group_size: torch.SymInt,
                                  sm_count: torch.SymInt,
                                  sm_version: torch.SymInt,
                                  CUBLAS_M_THRESHOLD: torch.SymInt,
                                  has_zp: bool,
                                  n32k16_reorder: bool) -> torch.Tensor:
        """Fake (shape-only) implementation of ``_C::allspark_w8a16_gemm``.

        Returns an uninitialised ``(a.size(0), n)`` tensor with the device
        and dtype of ``a``; every other argument is ignored.
        """
        out_shape = (a.size(0), n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)
if
hasattr
(
torch
.
ops
.
_C
,
"ggml_dequantize"
):
@
register_fake
(
"_C::ggml_dequantize"
)
...
...
@@ -881,6 +897,67 @@ def scaled_fp8_quant(
return
output
,
scale
# gptq allspark
def allspark_repack_weight(
        qweight: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor] = None,
        has_zp: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
    for Ampere W8A16 Fused Gemm kernel

    Args:
        qweight: uint8 weight tensor, original k x n format.
        scale: fp16/bf16 weight scale tensor, 1 x n format.
        zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
            Must be provided for asymmetric quantization.
        has_zp: if use symmetric quantization, has_zp = False.
            if use asymmetric quantization, has_zp = True.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
            rearranged weight, scale, and optionally zero_point.
            The weight has shape (N_32align, K) and scale / zero_point have
            shape (1, N_32align), where N_32align is N rounded up to a
            multiple of 32. The zero_point entry is None when has_zp is
            False.

    Raises:
        AssertionError: if has_zp is True but zero_point is None.
    """
    K = qweight.shape[0]
    N = qweight.shape[1]
    # Round N up to the next multiple of 32.
    N_32align = (N + 32 - 1) // 32 * 32

    # Output buffers; they are filled in place by the custom op below.
    qweight_reorder = torch.empty((N_32align, K),
                                  device=qweight.device,
                                  dtype=qweight.dtype)
    scale_reorder = torch.empty((1, N_32align),
                                device=scale.device,
                                dtype=scale.dtype)
    zero_point_reorder = None
    if has_zp:
        assert zero_point is not None, (
            "zero_point must be provided for asymmetric quantization.")
        zero_point_reorder = torch.empty((1, N_32align),
                                         device=zero_point.device,
                                         dtype=zero_point.dtype)

    torch.ops._C.rearrange_kn_weight_as_n32k16_order(
        qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder,
        zero_point_reorder, K, N, N_32align)

    return qweight_reorder, scale_reorder, zero_point_reorder
def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor,
                        b_scales: torch.Tensor,
                        b_qzeros: Optional[torch.Tensor], n: int,
                        group_size: int, sm_count: int, sm_version: int,
                        CUBLAS_M_THRESHOLD: int, has_zp: bool,
                        n32k16_reorder: bool) -> torch.Tensor:
    """Call the compiled ``_C.allspark_w8a16_gemm`` custom op.

    Every argument is forwarded positionally and unchanged, in the order
    of this function's signature, and the op's result is returned as-is.
    """
    gemm_op = torch.ops._C.allspark_w8a16_gemm
    return gemm_op(a, b_qweight, b_scales, b_qzeros, n, group_size,
                   sm_count, sm_version, CUBLAS_M_THRESHOLD, has_zp,
                   n32k16_reorder)
# int8
def
scaled_int8_quant
(
input
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
View file @
6a92ff93
...
...
@@ -3,6 +3,8 @@
from
typing
import
List
,
Optional
,
Type
import
vllm.envs
as
envs
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark
import
(
# noqa: E501
AllSparkLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama
import
(
# noqa: E501
ExllamaLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.machete
import
(
# noqa: E501
...
...
@@ -16,6 +18,7 @@ from vllm.platforms import current_platform
# Candidate MPLinearKernel implementations,
# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
    MacheteLinearKernel,
    AllSparkLinearKernel,
    MarlinLinearKernel,
    ExllamaLinearKernel,
]
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py
0 → 100644
View file @
6a92ff93
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.allspark_utils
import
(
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
,
check_allspark_supported_dtype_shape
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
class AllSparkLinearKernel(MPLinearKernel):
    """MPLinearKernel implementation that runs on the AllSpark W8A16 ops."""

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this kernel."""
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Report whether this kernel can serve the layer config ``c``.

        Returns ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        rejection: Optional[str] = None
        if c.has_g_idx:
            rejection = "Act reordering currently not supported by AllSpark"
        elif c.zero_points:
            rejection = "Zero points currently not supported by AllSpark"
        if rejection is not None:
            return False, rejection

        in_features = c.partition_weight_shape[0]
        out_features = c.partition_weight_shape[1]
        # Remaining dtype/shape/device constraints live in the shared helper.
        return check_allspark_supported_dtype_shape(in_features, out_features,
                                                    c.group_size,
                                                    c.weight_type, c.act_type)

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack the layer's weight and scale for the AllSpark GEMM.

        Stores the device's SM count and SM version in ``self.gemm_args``
        and replaces the weight and scale parameters on ``layer`` with
        their reordered versions.
        """
        weight_device = getattr(layer, self.w_q_name).device
        cfg = self.config

        # prepare the parameters required for the kernel
        dev_props = torch.cuda.get_device_properties(weight_device.index)
        self.gemm_args = {
            'sm_count': dev_props.multi_processor_count,
            'sm_version': dev_props.major * 10 + dev_props.minor,
        }

        # transform param weight, scale
        packed_weight = getattr(layer, self.w_q_name)
        loaded_scale = getattr(layer, self.w_s_name)

        assert isinstance(packed_weight, BasevLLMParameter)
        permute_param_layout_(packed_weight,
                              input_dim=0,
                              output_dim=1,
                              packed_dim=0)

        assert isinstance(loaded_scale, BasevLLMParameter)
        permute_param_layout_(loaded_scale, input_dim=0, output_dim=1)

        # unpack weight from K / 4 x N int32 to K x N uint8
        weight_holder = torch.nn.Parameter(packed_weight.data,
                                           requires_grad=False)
        byte_view = weight_holder.data.t().contiguous().view(
            dtype=torch.uint8)
        weight_holder.data = byte_view.t().contiguous()

        scale_holder = torch.nn.Parameter(loaded_scale.data,
                                          requires_grad=False)

        # reorder K x N weight as N32K16 format for Ampere W8A16
        reordered_weight, reordered_scale, _ = ops.allspark_repack_weight(
            weight_holder.data, scale_holder.data, None, cfg.zero_points)
        weight_holder.data = reordered_weight
        scale_holder.data = reordered_scale

        replace_parameter(layer, self.w_q_name, weight_holder.data)
        replace_parameter(layer, self.w_s_name, scale_holder.data)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the AllSpark W8A16 GEMM on ``x`` and optionally add ``bias``.

        ``x`` is flattened to 2-D for the op; the result is reshaped to
        ``x.shape[:-1] + (partition_weight_shape[1],)``.
        """
        cfg = self.config
        kernel_args = self.gemm_args
        qweight, scales, _, _ = self._get_weight_params(layer)

        out_features = cfg.partition_weight_shape[1]
        flat_x = x.reshape(-1, x.shape[-1])
        final_shape = x.shape[:-1] + (out_features, )

        result = ops.allspark_w8a16_gemm(
            a=flat_x,
            b_qweight=qweight,
            b_scales=scales,
            b_qzeros=None,
            n=out_features,
            group_size=cfg.group_size,
            sm_count=kernel_args['sm_count'],
            sm_version=kernel_args['sm_version'],
            CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
            has_zp=cfg.zero_points,
            n32k16_reorder=True)

        if bias is not None:
            result.add_(bias)  # In-place add

        return result.reshape(final_shape)
vllm/model_executor/layers/quantization/utils/allspark_utils.py
0 → 100644
View file @
6a92ff93
# SPDX-License-Identifier: Apache-2.0
import
torch
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
# Value handed to the AllSpark GEMM op as its CUBLAS_M_THRESHOLD argument.
# NOTE(review): presumably the M size at which the op switches to cuBLAS;
# confirm against the kernel.
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
# Weight quantization types accepted by the AllSpark support check.
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
# Per-partition output size (N) and input size (K) must be multiples of
# these values for the support check to pass.
ALLSPARK_AMPERE_N_ALIGN = 16
ALLSPARK_AMPERE_K_ALIGN = 16
def check_allspark_supported_dtype_shape(input_size_per_partition: int,
                                         output_size_per_partition: int,
                                         group_size: int,
                                         weight_dtype: ScalarType,
                                         act_dtype: torch.dtype):
    """Check whether AllSpark supports the given layer on this device.

    Args:
        input_size_per_partition: per-partition input size (K); must be a
            multiple of ALLSPARK_AMPERE_K_ALIGN.
        output_size_per_partition: per-partition output size (N); must be
            a multiple of ALLSPARK_AMPERE_N_ALIGN.
        group_size: quantization group size; only -1 is accepted.
        weight_dtype: weight quantization type; must be listed in
            ALLSPARK_SUPPORTED_QUANT_TYPES.
        act_dtype: activation dtype; must be float16 or bfloat16.

    Returns:
        ``(True, None)`` when every check passes, otherwise
        ``(False, reason)`` with a human-readable explanation.
    """
    capability_tuple = current_platform.get_device_capability()
    # -1 stands in for "capability unknown" and is rejected below.
    device_capability = (-1 if capability_tuple is None else
                         capability_tuple.to_int())

    # Only Ampere GPUs (capability 80..89) are handled.
    if not 80 <= device_capability < 90:
        return False, "AllSpark currently does not support "\
            f"device_capability = {device_capability}."

    if group_size != -1:
        return False, \
            "For Ampere GPU, AllSpark does not support group_size "\
            f"= {group_size}. Only group_size = -1 are supported."

    if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
        return False, "For Ampere GPU, AllSpark does not support "\
            f"quant type ({weight_dtype}). Only quant type "\
            f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."

    if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \
        or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0:
        return False, \
            "AllSpark needs input_size_per_partition % "\
            f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
            f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
            "for Ampere GPU optimized kernels."

    if act_dtype not in (torch.float16, torch.bfloat16):
        # The two literals are concatenated, so the separating space must
        # be part of one of them.
        return False, \
            "AllSpark only supports act_dtype = float16 or bfloat16, "\
            f"for Ampere GPU, but got act_dtype = {act_dtype}."

    return True, None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment