sglang · Commits · ad55f171

[quant kernel] sgl-kernel support per_tensor_quant fp8 (#3786)

Unverified commit ad55f171, authored Mar 07, 2025 by Xiaoyu Zhang, committed by GitHub on Mar 06, 2025.
Parent: 361971b8
Changes: 8 changed files with 343 additions and 0 deletions (+343, -0)

- sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py (+98, -0)
- sgl-kernel/setup.py (+1, -0)
- sgl-kernel/src/sgl-kernel/__init__.py (+1, -0)
- sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu (+163, -0)
- sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h (+1, -0)
- sgl-kernel/src/sgl-kernel/ops/gemm.py (+9, -0)
- sgl-kernel/src/sgl-kernel/torch_extension.cc (+3, -0)
- sgl-kernel/tests/test_per_tensor_quant_fp8.py (+67, -0)
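The commit adds a per-tensor FP8 quantization op, sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static): the caller preallocates the FP8 output tensor and a one-element float32 scale tensor, and is_static selects between using the provided scale and having the kernel derive it from the tensor's absolute maximum. A minimal usage sketch, adapted from the wrapper used in the benchmark and test files below (shapes and dtypes are illustrative):

    import torch
    from sgl_kernel import sgl_per_tensor_quant_fp8

    x = torch.randn(128, 4096, dtype=torch.float16, device="cuda")
    output = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scale = torch.zeros(1, dtype=torch.float32, device="cuda")

    # Dynamic quantization: the kernel writes absmax(x) / FP8_E4M3_MAX into `scale`.
    sgl_per_tensor_quant_fp8(x, output, scale, False)

    # Static quantization: reuse a precomputed scale as-is.
    sgl_per_tensor_quant_fp8(x, output, scale, True)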
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py (new file, 0 → 100644)

import itertools
import math
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops

from sglang.srt.utils import is_hip

is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


def vllm_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, scale)


def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    fp8_type_: torch.dtype = torch.float8_e4m3fn
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale


def calculate_diff(batch_size: int, seq_len: int):
    """Calculate difference between VLLM and SGLang implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

    scale_diff = torch.abs(vllm_scale - sglang_scale).item()
    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()

    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]

configs = list(itertools.product(batch_size_range, seq_len_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-tensor-quant-fp8-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider):
    dtype = torch.float16
    device = torch.device("cuda")
    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        fn = lambda: vllm_scaled_fp8_quant(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_scaled_fp8_quant(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark.run(print_data=True)
sgl-kernel/setup.py

@@ -106,6 +106,7 @@ sources = [
     "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
+    "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
     "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
     "src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
sgl-kernel/src/sgl-kernel/__init__.py

@@ -27,6 +27,7 @@ from sgl_kernel.ops.gemm import (
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     int8_scaled_mm,
+    sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_fp8,
 )
 from sgl_kernel.ops.moe import moe_align_block_size
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu (new file, 0 → 100644)

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>
#include <cub/block/block_reduce.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

#define WARP_SIZE 32

#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>

#include "amd/quant_utils.cuh"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
  return old;
}

__device__ __forceinline__ float warpReduceMax(float max_value) {
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
  return max_value;
}

template <typename T>
__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
                                         const int64_t num_elements) {
  float max_value = 0.0f;
  unsigned int tid = threadIdx.x;
  unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;

  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;
  const int32_t num_vec_elems = num_elements / vec_size;

  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = static_cast<float>(input_vec[j]);
      max_value = fmaxf(max_value, fabsf(val));
    }
  }

  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = static_cast<float>(input[idx]);
    max_value = fmaxf(max_value, fabsf(val));
  }

  static __shared__ float warpLevelMaxs[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;

  max_value = warpReduceMax(max_value);

  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
  __syncthreads();

  max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
  if (warpId == 0) max_value = warpReduceMax(max_value);

  if (tid == 0) {
    atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
  }
}

template <typename T>
__global__ void per_tensor_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output,
                                            const float* __restrict__ scale, const int64_t num_elements) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;
  const float scale_val = 1.0f / (*scale);

  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;
  const int32_t num_vec_elems = num_elements / vec_size;

  for (int32_t i = gid; i < num_vec_elems; i += grid_size) {
    vec_t input_vec;
    input_vec.cast_load(input + i * vec_size);

    FP8_TYPE output_arr[vec_size];
#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      output[i * vec_size + j] = output_arr[j];
    }
  }

  const int32_t remaining_start = num_vec_elems * vec_size;
  for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM
    output[idx] = static_cast<FP8_TYPE>(val);
#else
    output[idx] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}

void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, bool is_static) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const int block_size = 256;
  const int num_elements = input.numel();
  const int num_blocks = min((num_elements + block_size - 1) / block_size, 1024);

  dim3 grid(num_blocks);
  dim3 block(block_size);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (is_static == false) {
      per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
          static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
    }
    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()), num_elements);
    return true;
  });
}
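Numerically, the two kernels above implement plain absmax per-tensor quantization: the first reduces max|x| over the whole tensor and stores max|x| / FP8_E4M3_MAX as the scale, the second clamps x / scale to the FP8 range and casts. A pure-PyTorch sketch of the same math for reference only (not part of the commit; dtype, constants, and function name are illustrative, CUDA path only):

    from typing import Optional, Tuple

    import torch

    FP8_DTYPE = torch.float8_e4m3fn            # ROCm builds use float8_e4m3fnuz with a 224.0 cap
    FP8_E4M3_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

    def per_tensor_quant_fp8_reference(
        x: torch.Tensor, scale: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if scale is None:
            # per_tensor_absmax_kernel: one global scale from the tensor's absolute maximum
            scale = x.abs().amax().float() / FP8_E4M3_MAX
        # per_tensor_quant_fp8_kernel: scale, clamp to the representable range, cast to FP8
        q = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(FP8_DTYPE)
        return q, scale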
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h

@@ -92,6 +92,7 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
     const torch::Dtype& out_dtype);

 void sgl_per_token_group_quant_fp8(
     at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, double eps, double fp8_min,
     double fp8_max);

+void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);

 void cublas_grouped_gemm(
     const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
     const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype, int64_t cublas_handle,
     int64_t cuda_stream);
sgl-kernel/src/sgl-kernel/ops/gemm.py

@@ -91,6 +91,15 @@ def sgl_per_token_group_quant_fp8(
     )


+def sgl_per_tensor_quant_fp8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    is_static: bool,
+) -> None:
+    torch.ops.sgl_kernels.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+
+
 def cublas_grouped_gemm(
     inputs: List[torch.Tensor],
     weights: List[torch.Tensor],
sgl-kernel/src/sgl-kernel/torch_extension.cc

@@ -90,6 +90,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       " float eps, float fp8_min, float fp8_max) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);

+  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
+  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
+
   m.def(
       "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
       " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
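With the schema defined and the CUDA implementation registered, the op becomes reachable through the PyTorch dispatcher; the thin wrapper added to ops/gemm.py above simply forwards to it. A quick sanity check of the binding (illustrative, assuming the extension is built and the package imports cleanly):

    import torch
    import sgl_kernel  # importing the package loads the extension and registers torch.ops.sgl_kernels

    # The dispatcher entry should match the schema string defined above.
    print(torch.ops.sgl_kernels.sgl_per_tensor_quant_fp8)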
sgl-kernel/tests/test_per_tensor_quant_fp8.py (new file, 0 → 100644)

import itertools
from typing import Optional, Tuple

import pytest
import torch
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops

from sglang.srt.utils import is_hip

is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn


def vllm_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, scale)


def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    fp8_type_: torch.dtype = torch.float8_e4m3fn
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale


@pytest.mark.parametrize(
    "num_tokens,hidden_dim",
    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
)
def test_per_tensor_quant_compare_implementations(
    num_tokens: int,
    hidden_dim: int,
):
    device = torch.device("cuda")
    x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3)

    scale = torch.rand(1, dtype=torch.float32, device=device)
    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x, scale)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)

    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    pytest.main([__file__])