Unverified commit 2c8fd993, authored Apr 03, 2025 by Xiaoyu Zhang, committed by GitHub Apr 02, 2025

[sgl-kernel] per token group quant support COLUMN MAJOR (#4817)

parent 31da75ab

Showing 3 changed files with 252 additions and 80 deletions (+252 −80)
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py  +7 −3
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu        +49 −19
sgl-kernel/tests/test_per_token_group_quant_8bit.py       +196 −58
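The point of the change: the sgl-kernel CUDA launcher can now write per-token-group scales into a column-major output_s tensor (the layout used by the column-major Triton kernel and by TMA-aligned allocations), selecting the code path at launch time from the tensor's strides. As a minimal sketch of the two layouts, with illustrative sizes and variable names of my own (not from the commit):

import torch

num_tokens, hidden_dim, group_size = 8, 256, 128
groups_per_token = hidden_dim // group_size

# Row-major scales: one contiguous row of scales per token.
s_row = torch.empty(num_tokens, groups_per_token, dtype=torch.float32)

# Column-major scales: allocate transposed, then permute so indexing stays
# [token, group] while the underlying memory is column major.
s_col = torch.empty(groups_per_token, num_tokens, dtype=torch.float32).permute(-1, -2)

print(s_row.stride())                     # (2, 1): row major
print(s_col.stride())                     # (1, 8): column major
print(s_col.stride(0) < s_col.stride(1))  # True -- the check the new launcher uses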
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py

...
@@ -148,9 +148,11 @@ def sglang_per_token_group_quant_8bit(
 def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
     device = torch.device("cuda")
-    hidden_dim = group_size * 2
-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
+    hidden_dim = 7168
+    x = torch.randn(
+        batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
+    )
     x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
         x.clone(), group_size, dst_dtype
...
@@ -196,7 +198,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
     device = torch.device("cuda")
     hidden_dim = 7168
-    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
+    x = torch.randn(
+        batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
+    )
     quantiles = [0.5, 0.2, 0.8]
...
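The benchmark now feeds a flat 2D [num_tokens, hidden_dim] matrix instead of a 3D (batch, seq, hidden) tensor. A quick sketch of the resulting group and scale counts, using an illustrative group_size of 128 (the benchmark itself sweeps several group sizes):

# Benchmark shapes after this change (group_size = 128 is only an example).
batch_size, seq_len, hidden_dim, group_size = 4, 1024, 7168, 128

num_tokens = batch_size * seq_len            # 4096 rows in the flattened 2D input
groups_per_token = hidden_dim // group_size  # 56 quantization groups per token
num_groups = num_tokens * groups_per_token   # 229376 scales, one per group

print((num_tokens, hidden_dim), (num_tokens, groups_per_token), num_groups)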
sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu

...
@@ -16,7 +16,7 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }
 
-template <typename T, typename DST_DTYPE>
+template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false>
 __global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
...
@@ -26,19 +26,30 @@ __global__ void per_token_group_quant_8bit_kernel(
     const int groups_per_block,
     const float eps,
     const float min_8bit,
-    const float max_8bit) {
+    const float max_8bit,
+    const int scale_num_rows = 0,
+    const int scale_stride = 0) {
   const int threads_per_group = 16;
   const int local_group_id = threadIdx.x / threads_per_group;
   const int lane_id = threadIdx.x % threads_per_group;
 
   const int block_group_id = blockIdx.x * groups_per_block;
-  const int block_group_offset = (block_group_id + local_group_id) * group_size;
+  const int global_group_id = block_group_id + local_group_id;
+  const int block_group_offset = global_group_id * group_size;
 
   float local_absmax = eps;
 
   const T* group_input = input + block_group_offset;
   DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
-  float* scale_output = output_s + (block_group_id + local_group_id);
+  float* scale_output;
+
+  if constexpr (IS_COLUMN_MAJOR) {
+    const int row_idx = global_group_id / scale_num_rows;
+    const int col_idx = global_group_id % scale_num_rows;
+    scale_output = output_s + (col_idx * scale_stride + row_idx);
+  } else {
+    scale_output = output_s + global_group_id;
+  }
 
   constexpr uint32_t vec_size = 16 / sizeof(T);
   using vec_t = flashinfer::vec_t<T, vec_size>;
...
@@ -88,11 +99,11 @@ void sgl_per_token_group_quant_8bit(
     double max_8bit) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
-  CHECK_INPUT(output_s);
 
   const int num_groups = input.numel() / group_size;
 
   CHECK_EQ(input.numel() % group_size, 0);
+  CHECK_EQ(output_s.dim(), 2);
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...
@@ -114,20 +125,39 @@ void sgl_per_token_group_quant_8bit(
   const int num_blocks = num_groups / groups_per_block;
   const int num_threads = groups_per_block * THREADS_PER_GROUP;
 
-#define LAUNCH_KERNEL(T, DST_DTYPE)                                              \
-  do {                                                                           \
-    dim3 grid(num_blocks);                                                       \
-    dim3 block(num_threads);                                                     \
-    per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>( \
-        static_cast<T*>(input.data_ptr()),                                       \
-        output_q.data_ptr(),                                                     \
-        static_cast<float*>(output_s.data_ptr()),                                \
-        group_size,                                                              \
-        num_groups,                                                              \
-        groups_per_block,                                                        \
-        (float)eps,                                                              \
-        (float)min_8bit,                                                         \
-        (float)max_8bit);                                                        \
+  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
+  const int scale_num_rows = output_s.size(1);
+  const int scale_stride = output_s.stride(1);
+
+#define LAUNCH_KERNEL(T, DST_DTYPE)                                                       \
+  do {                                                                                    \
+    dim3 grid(num_blocks);                                                                \
+    dim3 block(num_threads);                                                              \
+    if (is_column_major) {                                                                \
+      per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>(  \
+          static_cast<T*>(input.data_ptr()),                                              \
+          output_q.data_ptr(),                                                            \
+          static_cast<float*>(output_s.data_ptr()),                                       \
+          group_size,                                                                     \
+          num_groups,                                                                     \
+          groups_per_block,                                                               \
+          (float)eps,                                                                     \
+          (float)min_8bit,                                                                \
+          (float)max_8bit,                                                                \
+          scale_num_rows,                                                                 \
+          scale_stride);                                                                  \
+    } else {                                                                              \
+      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
+          static_cast<T*>(input.data_ptr()),                                              \
+          output_q.data_ptr(),                                                            \
+          static_cast<float*>(output_s.data_ptr()),                                       \
+          group_size,                                                                     \
+          num_groups,                                                                     \
+          groups_per_block,                                                               \
+          (float)eps,                                                                     \
+          (float)min_8bit,                                                                \
+          (float)max_8bit);                                                               \
+    }                                                                                     \
   } while (0)
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
...
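For the column-major case the kernel derives each group's scale offset from output_s.size(1) and output_s.stride(1). A small Python cross-check of that index math, as a sketch with made-up sizes (colmajor_scale_offset is my name, not a function in the commit):

import torch

def colmajor_scale_offset(global_group_id: int, scale_num_rows: int, scale_stride: int) -> int:
    # Mirrors the IS_COLUMN_MAJOR branch of the CUDA kernel above.
    row_idx = global_group_id // scale_num_rows
    col_idx = global_group_id % scale_num_rows
    return col_idx * scale_stride + row_idx

num_tokens, groups_per_token = 4, 3
# Column-major scale tensor, still indexed as [token, group] like the row-major one.
x_s = torch.empty(groups_per_token, num_tokens, dtype=torch.float32).permute(-1, -2)

for token in range(num_tokens):
    for group in range(groups_per_token):
        g = token * groups_per_token + group  # groups in the order the kernel enumerates them
        offset = colmajor_scale_offset(g, x_s.size(1), x_s.stride(1))
        # The flat offset the kernel would write matches PyTorch's strided indexing.
        assert offset == token * x_s.stride(0) + group * x_s.stride(1)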
sgl-kernel/tests/test_per_token_group_quant_8bit.py

...
@@ -9,12 +9,12 @@ from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_
 from sglang.srt.utils import is_hip
 
-is_hip_ = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+_is_hip = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
 
 
 @triton.jit
-def _per_token_group_quant_8bit(
+def _per_token_group_quant_fp8(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
...
@@ -25,15 +25,16 @@ def _per_token_group_quant_8bit(
     N,
     # Avoid to divide zero
     eps,
-    # Information for 8bit data type (int8 or fp8_type_)
-    max_8bit,
-    min_8bit,
+    # Information for float8
+    fp8_min,
+    fp8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group quantization on a
     tensor.
-    This function converts the tensor values into 8bit values.
+    This function converts the tensor values into float8 values.
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
...
@@ -47,8 +48,57 @@ def _per_token_group_quant_8bit(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / max_8bit
-    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
+    y_s = _absmax / fp8_max
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)
+
+
+@triton.jit
+def _per_token_group_quant_fp8_colmajor(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Num columns of y
+    y_num_columns,
+    # Stride from one column to the next of y_s
+    y_s_col_stride,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor.
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+
+    # Convert g_id the flattened block coordinate to 2D so we can index
+    # into the output y_scales matrix
+    blocks_per_row = y_num_columns // group_size
+    scale_col = g_id % blocks_per_row
+    scale_row = g_id // blocks_per_row
+    y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
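The quantization math itself is unchanged by the layout work: each group is scaled by absmax / fp8_max and clamped into the representable range. A dependency-free reference of that per-group step, as a sketch (torch_ref_quant is my name, not part of the test file):

import torch

def torch_ref_quant(y: torch.Tensor, fp8_min: float, fp8_max: float, eps: float = 1e-10):
    # y: one quantization group, any float dtype.
    y = y.to(torch.float32)
    absmax = torch.clamp(y.abs().max(), min=eps)
    y_s = absmax / fp8_max
    y_q = torch.clamp(y / y_s, fp8_min, fp8_max)
    return y_q, y_s

group = torch.randn(128)
q, s = torch_ref_quant(group, fp8_min=-448.0, fp8_max=448.0)
# q can then be cast to torch.float8_e4m3fn; s is the single float32 scale for the group.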
...
@@ -57,17 +107,22 @@ def _per_token_group_quant_8bit(
 def triton_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
-    dst_dtype: torch.dtype,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+        dtype: The dype of output tensor.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
     """
...
@@ -76,41 +131,79 @@ def triton_per_token_group_quant_8bit(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    if dst_dtype == torch.int8:
-        iinfo = torch.iinfo(dst_dtype)
-        max_8bit = iinfo.max
-        min_8bit = iinfo.min
+    if dtype == torch.int8:
+        finfo = torch.iinfo(dtype)
     else:
-        finfo = torch.finfo(dst_dtype)
-        max_8bit = finfo.max
-        min_8bit = finfo.min
+        finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+
+    if _is_hip:
+        if dtype == torch.int8:
+            fp8_max = 127.0
+        else:
+            fp8_max = 224.0
+
+    fp8_min = -fp8_max
 
-    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_8bit[(M,)](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        max_8bit,
-        min_8bit,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            N,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
 
     return x_q, x_s
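The scale_tma_aligned branch pads the token dimension of the scale buffer to a multiple of 4 floats (16 bytes) before taking the column-major view, then slices back to the real token count. A small sketch of the resulting shape and strides, with illustrative sizes:

import torch

num_tokens, hidden_dim, group_size = 127, 512, 128
groups_per_token = hidden_dim // group_size  # 4

# aligned to 4 * sizeof(float): 127 tokens are padded to 128 rows of storage
aligned_size = (num_tokens + 3) // 4 * 4     # 128

x_s = torch.empty(
    (groups_per_token, aligned_size), dtype=torch.float32
).permute(-1, -2)[:num_tokens, :]

print(x_s.shape)     # torch.Size([127, 4]): still indexed [token, group]
print(x_s.stride())  # (1, 128): column major, with the padded stride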
...
@@ -118,28 +211,48 @@ def triton_per_token_group_quant_8bit(
 def sglang_per_token_group_quant_8bit(
     x: torch.Tensor,
     group_size: int,
-    dst_dtype: torch.dtype,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
-    if dst_dtype == torch.int8:
-        iinfo = torch.iinfo(dst_dtype)
+    if dtype == torch.int8:
+        iinfo = torch.iinfo(dtype)
         int8_max = iinfo.max
         int8_min = iinfo.min
         sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
     else:
-        f8_info = torch.finfo(dst_dtype)
+        f8_info = torch.finfo(dtype)
         fp8_max = f8_info.max
         fp8_min = f8_info.min
         sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
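Both wrappers derive the clamp bounds from the destination dtype, so one entry point covers int8 and both fp8 flavors. For reference, the bounds torch.iinfo / torch.finfo report for the dtypes exercised by the test (a sketch, not part of the commit):

import torch

for dst in (torch.int8, torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    info = torch.iinfo(dst) if dst == torch.int8 else torch.finfo(dst)
    print(dst, info.min, info.max)

# torch.int8             -128    127
# torch.float8_e4m3fn    -448.0  448.0
# torch.float8_e4m3fnuz  -240.0  240.0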
...
@@ -148,30 +261,55 @@ def sglang_per_token_group_quant_8bit(
 @pytest.mark.parametrize(
-    "batch_size, seq_len, group_size, dst_dtype",
+    "num_tokens, hidden_dim, group_size, dst_dtype, column_major_scales, scale_tma_aligned",
     list(
         itertools.product(
-            [1, 2, 4, 8, 16, 32, 64, 128],  # batch_size
-            [64, 128, 256, 512, 1024, 2048],  # seq_len
-            [16, 32, 64, 128, 256],  # group_size
+            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
+            [256, 512, 1024, 2048, 4096],  # hidden_dim
+            [8, 16, 32, 64, 128],  # group_size
             [torch.int8, fp8_type_],  # dtype
+            [False, True],  # column_major_scales
+            [False, True],  # scale_tma_aligned
         )
     ),
 )
-def test_per_token_group_quant_compare_implementations(
-    batch_size, seq_len, group_size, dst_dtype
+def test_per_token_group_quant_with_column_major(
+    num_tokens,
+    hidden_dim,
+    group_size,
+    dst_dtype,
+    column_major_scales,
+    scale_tma_aligned,
 ):
-    x = torch.randn(
-        (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16
-    )
+    if not column_major_scales and scale_tma_aligned:
+        return
+
+    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float16)
 
-    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
+    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
+        x,
+        group_size,
+        eps=1e-10,
+        dtype=dst_dtype,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+    )
+
+    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
+        x,
+        group_size,
+        eps=1e-10,
+        dtype=dst_dtype,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+    )
 
     assert torch.allclose(
         x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
     )
-    assert torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5)
+    assert torch.allclose(
+        x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
+    )
 
 
 if __name__ == "__main__":
...