Commit bb418ced (unverified), authored Feb 11, 2025 by Xiaoyu Zhang, committed by GitHub on Feb 11, 2025
optimize per token group quant fp8 (#3490)
parent fdf04a14

Showing 8 changed files with 509 additions and 0 deletions (+509, −0)
sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py (+209, −0)
sgl-kernel/setup.py (+1, −0)
sgl-kernel/src/sgl-kernel/__init__.py (+2, −0)
sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu (+100, −0)
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h (+4, −0)
sgl-kernel/src/sgl-kernel/ops/__init__.py (+14, −0)
sgl-kernel/src/sgl-kernel/torch_extension.cc (+6, −0)
sgl-kernel/tests/test_per_token_group_quant_fp8.py (+173, −0)
sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py (new file, 0 → 100644)
import itertools
import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import triton
import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip

is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Avoid dividing by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group quantization on a
    tensor.
    This function converts the tensor values into float8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def triton_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.
    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum value to avoid dividing by zero.
        dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_fp8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s


def sglang_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
):
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s


def calculate_diff(batch_size, seq_len, group_size):
    dtype = torch.float16
    device = torch.device("cuda")
    hidden_dim = group_size * 2

    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)

    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x.clone(), group_size)
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x.clone(), group_size)

    if torch.allclose(
        x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [1, 2, 4, 8, 16, 32, 64]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
group_size_range = [128]  # For DeepSeek V3/R1

configs = list(itertools.product(batch_size_range, seq_len_range, group_size_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "group_size"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        line_names=["Triton", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-fp8-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, group_size, provider):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    hidden_dim = group_size * 2

    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "triton":
        fn = lambda: triton_per_token_group_quant_fp8(x.clone(), group_size)
    elif provider == "sglang":
        fn = lambda: sglang_per_token_group_quant_fp8(x.clone(), group_size)

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=128, group_size=64)
    benchmark.run(print_data=True)
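For readers skimming the benchmark, the arithmetic that both implementations perform can be sketched on the CPU: each group of `group_size` values gets a scale of max(|x|, eps) / fp8_max, and the quantized value is clamp(x / scale, fp8_min, fp8_max). The sketch below is illustrative only; the helper name and the use of float32 as a stand-in for FP8 storage are assumptions for readability, not part of this commit (the real kernels write torch.float8_e4m3fn on the GPU).

# Minimal CPU sketch of the per-token-group quantization math used above.
# Quantized values are kept in float32 instead of torch.float8_e4m3fn so it
# runs without a GPU; it is not the shipped kernel.
import torch

def reference_per_token_group_quant(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    assert x.shape[-1] % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    fp8_min = -fp8_max
    groups = x.float().reshape(-1, group_size)                 # one row per group
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / fp8_max                                    # per-group scale
    q = (groups / scale).clamp(fp8_min, fp8_max)                # values in FP8 range
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)

x = torch.randn(2, 4, 128)
q, s = reference_per_token_group_quant(x, group_size=128)
print(q.shape, s.shape)  # torch.Size([2, 4, 128]) torch.Size([2, 4, 1])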
sgl-kernel/setup.py
@@ -100,6 +100,7 @@ sources = [
     "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
     "src/sgl-kernel/csrc/eagle_utils.cu",
     "src/sgl-kernel/csrc/speculative_sampling.cu",
+    "src/sgl-kernel/csrc/per_token_group_quant_fp8.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
...
sgl-kernel/src/sgl-kernel/__init__.py
@@ -29,6 +29,7 @@ from sgl_kernel.ops import (
     register_graph_buffers,
     rmsnorm,
     sampling_scaling_penalties,
+    sgl_per_token_group_quant_fp8,
     silu_and_mul,
     top_k_renorm_prob,
     top_k_top_p_sampling_from_probs,
...
@@ -65,4 +66,5 @@ __all__ = [
     "tree_speculative_sampling_target_only",
     "build_tree_kernel_efficient",
     "build_tree_kernel",
+    "sgl_per_token_group_quant_fp8",
 ]
sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu (new file, 0 → 100644)

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>

#include "utils.h"

using FP8_TYPE = c10::Float8_e4m3fn;

__device__ __forceinline__ float WarpReduce(volatile float* smem, const int tid) {
  if (tid < 8) {
    smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
    if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
    if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
    if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
  }
  return smem[0];
}

template <typename T>
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int groups_per_block = 16;
  const int block_group_id = blockIdx.x * groups_per_block;
  const int tid = threadIdx.x;
  const int local_group_id = tid / 16;  // Each 16 threads handle one group
  const int local_tid = tid % 16;       // Thread ID within the group

  __shared__ float s_absmax[16][17];  // Use 17 instead of 16 to avoid bank conflicts

  // Local maximum value for each thread
  float local_absmax = eps;

  // Ensure this block doesn't process out-of-bounds groups
  if (block_group_id + local_group_id < num_groups) {
    // Calculate input/output pointers for current group
    const T* group_input = input + (block_group_id + local_group_id) * group_size;
    FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + (block_group_id + local_group_id) * group_size;
    float* scale_output = output_s + block_group_id + local_group_id;

    // Calculate local maximum absolute value
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      float abs_val = fabsf(val);
      local_absmax = fmaxf(local_absmax, abs_val);
    }

    // Store in shared memory
    s_absmax[local_group_id][local_tid] = local_absmax;
    __syncthreads();

    // Perform reduction within each group
    if (local_tid < 8) {
      WarpReduce(&s_absmax[local_group_id][0], local_tid);
    }
    __syncthreads();

    // Get the maximum value for this group
    const float group_absmax = s_absmax[local_group_id][0];
    const float y_s = group_absmax / fp8_max;

    // Only the first thread in each group writes the scale
    if (local_tid == 0) {
      *scale_output = y_s;
    }

    // Quantize the data
    for (int i = local_tid; i < group_size; i += 16) {
      float val = static_cast<float>(group_input[i]);
      float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
      group_output[i] = FP8_TYPE(q_val);
    }
  }
}

void sgl_per_token_group_quant_fp8(
    torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, int64_t group_size, double eps,
    double fp8_min, double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const int num_groups = input.numel() / group_size;

  CHECK_EQ(input.numel() % group_size, 0);

  // Each block processes 16 groups, adjust grid size accordingly
  dim3 grid((num_groups + 15) / 16);
  dim3 block(256);  // Keep 256 threads, each 16 threads handle one group

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        output_q.data_ptr(),
        static_cast<float*>(output_s.data_ptr()),
        group_size,
        num_groups,
        (float)eps,
        (float)fp8_min,
        (float)fp8_max);
    return true;
  });
}
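The launch configuration above fixes 256 threads per block with 16 threads cooperating on each group, so every block covers 16 groups and the grid size is ceil(num_groups / 16). A small sketch of that geometry follows; the helper name and example shape are illustrative assumptions, not part of this commit.

# Sketch of the launch geometry used by per_token_group_quant_fp8_kernel:
# 256 threads per block, 16 threads per group -> 16 groups per block,
# grid = ceil(num_groups / 16).
def launch_geometry(numel: int, group_size: int):
    threads_per_block = 256
    threads_per_group = 16
    groups_per_block = threads_per_block // threads_per_group  # 16
    num_groups = numel // group_size
    grid = (num_groups + groups_per_block - 1) // groups_per_block
    return grid, threads_per_block, num_groups

# e.g. batch=4, seq_len=128, hidden=128, group_size=64 (the calculate_diff default)
grid, block, num_groups = launch_geometry(4 * 128 * 128, 64)
print(grid, block, num_groups)  # 64 256 1024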
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
@@ -143,3 +143,7 @@
 void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                        at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
                        int64_t depth, int64_t draft_token_num);
+
+// sgl_per_token_group_quant_fp8
+void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
+                                   double eps, double fp8_min, double fp8_max);
sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -579,3 +579,17 @@ def build_tree_kernel(
         depth,
         draft_token_num,
     )
+
+
+def sgl_per_token_group_quant_fp8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    group_size: int,
+    eps: float,
+    fp8_min: float,
+    fp8_max: float,
+) -> None:
+    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
+        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
+    )
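As the wrapper above shows, the op uses a destination-passing style: the caller allocates the quantized tensor and the per-group scales, and the kernel fills them in place. A minimal call sketch, mirroring the sglang_per_token_group_quant_fp8 helper from the benchmark and test files in this commit, is shown below; it assumes a CUDA device and a build of sgl-kernel that includes this change.

# Minimal call sketch for the new op (assumes CUDA + an sgl-kernel build with this commit).
import torch
from sgl_kernel import sgl_per_token_group_quant_fp8

group_size = 128
x = torch.randn(4, 128, 2 * group_size, device="cuda", dtype=torch.float16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)  # quantized output, filled by the kernel
x_s = torch.empty(
    *x.shape[:-1], x.shape[-1] // group_size, device=x.device, dtype=torch.float32
)  # per-group scales, filled by the kernel

sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, -fp8_max, fp8_max)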
sgl-kernel/src/sgl-kernel/torch_extension.cc
@@ -153,6 +153,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
         "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
         "int topk, int depth, int draft_token_num) -> ()");
   m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);
+
+  // per_token_group_quant_fp8
+  m.def(
+      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
 }
 
 REGISTER_EXTENSION(_kernels)
sgl-kernel/tests/test_per_token_group_quant_fp8.py (new file, 0 → 100644)

import itertools
from typing import Any, Dict, List, Optional, Tuple

import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip

is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Avoid dividing by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group quantization on a
    tensor.
    This function converts the tensor values into float8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def triton_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.
    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum value to avoid dividing by zero.
        dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_fp8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s


def sglang_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
):
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s


@pytest.mark.parametrize(
    "batch_size, seq_len, group_size",
    list(
        itertools.product(
            [1, 2, 4, 8, 16],  # batch_size
            [64, 128, 256, 512, 1024, 2048],  # seq_len
            [64, 128, 256],  # group_size
        )
    ),
)
def test_per_token_group_quant_compare_implementations(batch_size, seq_len, group_size):
    x = torch.randn(
        (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16
    )

    x_q_triton, x_s_triton = triton_per_token_group_quant_fp8(x, group_size)
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_fp8(x, group_size)

    assert torch.allclose(
        x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
    )
    assert torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5)


if __name__ == "__main__":
    pytest.main([__file__])