change / sglang / Commits / e0917e6b

Commit e0917e6b (unverified), authored Mar 12, 2025 by Stefan He, committed by GitHub on Mar 12, 2025

Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215)

Co-authored-by: Stefan He <bhe@linkedin.com>
Parent: 7130a7ce

Showing 5 changed files with 202 additions and 37 deletions (+202, -37):

  python/sglang/srt/custom_op.py                               +59   -0
  python/sglang/srt/layers/moe/ep_moe/layer.py                 +22   -8
  python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py   +14   -6
  python/sglang/test/test_custom_ops.py                        +88   -0
  sgl-kernel/csrc/gemm/per_token_quant_fp8.cu                  +19  -23
python/sglang/srt/custom_op.py

+from typing import Optional
+
 import torch
 from torch import nn
...
@@ -40,3 +42,60 @@ class CustomOp(nn.Module):
             return self.forward_hip
         else:
             return self.forward_native
+
+
+if _is_cuda:
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantize input tensor to FP8 (8-bit floating point) format.
+
+        Args:
+            input (torch.Tensor): Input tensor to be quantized
+            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+                If None, scales will be computed dynamically.
+            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+                determines the quantization granularity:
+                - True: compute scale per token
+                - False: compute single scale per tensor
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - quantized_tensor: The FP8 quantized version of input
+                - scale_tensor: The scaling factors used for quantization
+
+        Raises:
+            AssertionError: If input is not 2D or if static scale's numel != 1
+        """
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+        output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                sgl_per_token_quant_fp8(input, output, scale)
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                sgl_per_tensor_quant_fp8(input, output, scale, is_static=False)  # False for dynamic
+        else:
+            # Static scaling
+            assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+            sgl_per_tensor_quant_fp8(input, output, scale, is_static=True)  # True for static
+
+        return output, scale
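For reference, a minimal usage sketch of the wrapper added above (not part of the commit); it assumes a CUDA build of sglang with the sgl_kernel extension installed:

import torch

from sglang.srt.custom_op import scaled_fp8_quant

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# Dynamic per-token quantization: one FP8 scale per row, shape (16, 1).
q_tok, s_tok = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)

# Dynamic per-tensor quantization: a single scalar scale.
q_tensor, s_tensor = scaled_fp8_quant(x, None)

# Static per-tensor quantization reuses a precomputed scalar scale.
q_static, _ = scaled_fp8_quant(x, s_tensor)

print(q_tok.dtype, s_tok.shape, s_tensor.numel())  # torch.float8_e4m3fn torch.Size([16, 1]) 1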
python/sglang/srt/layers/moe/ep_moe/layer.py

...
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
-from vllm import _custom_ops as ops
+from vllm import _custom_ops as vllm_ops

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
...
@@ -26,7 +26,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-from sglang.srt.utils import is_hip, set_weight_attrs
+from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+
 logger = logging.getLogger(__name__)
...
@@ -719,12 +725,20 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )

             for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
+                if _is_cuda:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
+                else:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             return
...
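The requantization loop above dispatches to the sglang kernel on CUDA and falls back to vLLM elsewhere. An illustration of the per-expert pattern in isolation, with hypothetical shapes (a sketch, not code from the commit):

import torch

from sglang.srt.custom_op import scaled_fp8_quant

num_experts, n, k = 4, 512, 256  # hypothetical expert-weight sizes
w = torch.randn(num_experts, n, k, device="cuda", dtype=torch.float16)

w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn)
w_scale = torch.empty(num_experts, device="cuda", dtype=torch.float32)

for e in range(num_experts):
    # Dynamic per-tensor quantization of one 2D expert slice:
    # each expert gets its own FP8 weight slice plus one scalar scale.
    w_fp8[e, :, :], w_scale[e] = scaled_fp8_quant(w[e, :, :])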
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

...
@@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-from vllm import _custom_ops as ops
+from vllm import _custom_ops as vllm_ops

 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
...
@@ -42,6 +42,7 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul

+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
     )
...
@@ -486,7 +487,7 @@ def moe_align_block_size(
             cumsum_buffer,
         )
     else:
-        ops.moe_align_block_size(
+        vllm_ops.moe_align_block_size(
             topk_ids,
             num_experts,
             block_size,
...
@@ -527,7 +528,10 @@ def invoke_fused_moe_kernel(
     if block_shape is None:
         # activation tensor-wise fp8 quantization, dynamic or static
         padded_size = padding_size
-        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        if _is_cuda:
+            A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+        else:
+            A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
     else:
         # activation block-wise fp8 quantization
         assert len(block_shape) == 2
...
@@ -1095,12 +1099,16 @@ def fused_experts_impl(
             if _is_cuda:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
-                ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+                vllm_ops.silu_and_mul(
+                    intermediate_cache2, intermediate_cache1.view(-1, N)
+                )
         elif activation == "gelu":
             if _is_cuda:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
-                ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+                vllm_ops.gelu_and_mul(
+                    intermediate_cache2, intermediate_cache1.view(-1, N)
+                )
         else:
             raise ValueError(f"Unsupported activation: {activation=}")
...
@@ -1132,7 +1140,7 @@ def fused_experts_impl(
         if no_combine:
             pass
         elif _is_hip:
-            ops.moe_sum(
+            vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
...
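As context for the activation calls above: the sgl_kernel version takes (input, output) while the vLLM op takes (output, input), but both compute the same gated activation. A rough pure-PyTorch sketch of what either fused silu_and_mul op produces (illustration only, not the kernel itself):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x has shape (..., 2 * d): the first half is passed through SiLU and
    # multiplied element-wise with the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(8, 2 * 1024)
out = silu_and_mul_ref(x)  # shape (8, 1024)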
python/sglang/test/test_custom_ops.py (new file, mode 100644)

# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py
import pytest
import torch

from sglang.srt.custom_op import scaled_fp8_quant
from sglang.srt.utils import is_cuda


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant_per_tensor(dtype) -> None:
    def quantize_ref_per_tensor(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def dequantize_per_tensor(tensor, inv_scale, dtype):
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 8 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 8.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Test Per Tensor Dynamic quantization
    # scale = max(abs(x)) / FP8_E4M3_MAX
    y, scale = scaled_fp8_quant(x, None)
    ref_y = quantize_ref_per_tensor(x, scale)
    torch.testing.assert_close(y, ref_y)
    torch.testing.assert_close(
        dequantize_per_tensor(y, scale, dtype),
        dequantize_per_tensor(ref_y, scale, dtype),
    )

    # Test Per Tensor Static quantization
    y, _ = scaled_fp8_quant(x, scale)
    ref_y = quantize_ref_per_tensor(x, scale)
    torch.testing.assert_close(y, ref_y)
    torch.testing.assert_close(
        dequantize_per_tensor(y, scale, dtype),
        dequantize_per_tensor(ref_y, scale, dtype),
    )


if is_cuda:

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
        def quantize_ref_per_token(tensor, inv_scale):
            # The reference implementation that fully aligns to
            # the kernel being tested.
            finfo = torch.finfo(torch.float8_e4m3fn)
            scale = inv_scale.reciprocal()
            qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
            qweight = qweight.to(torch.float8_e4m3fn)
            return qweight

        def dequantize_per_token(tensor, inv_scale, dtype):
            fake_qweight = tensor.to(dtype)
            dq_weight = fake_qweight * inv_scale
            return dq_weight

        # Note that we use a shape % 8 == 0,
        # because per_token_quant_fp8 is vectorized by 8 elements.
        x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)

        # Test Per Token Dynamic quantization
        # scale[i] = max(abs(x[i])) / FP8_E4M3_MAX
        y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
        ref_y = quantize_ref_per_token(x, scale)
        torch.testing.assert_close(y, ref_y)
        torch.testing.assert_close(
            dequantize_per_token(y, scale, dtype),
            dequantize_per_token(ref_y, scale, dtype),
        )


if __name__ == "__main__":
    # Run the specific test function directly
    pytest.main([__file__])
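The per-token test above checks the kernel against a reference built from the scale the kernel itself returns. A standalone pure-PyTorch sketch of the dynamic per-token semantics it verifies (illustration only; the commit implements this in the sgl_per_token_quant_fp8 CUDA kernel):

import torch

def per_token_quant_fp8_ref(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One scale per row: scale[i] = max(abs(x[i])) / FP8_E4M3_MAX
    scale = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / finfo.max
    q = (x.to(torch.float32) / scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn), scale

x = torch.randn(11, 16) * 13
q, scale = per_token_quant_fp8_ref(x)
assert q.shape == x.shape and scale.shape == (11, 1)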
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu

...
@@ -14,7 +14,6 @@ __global__ void per_token_quant_fp8_kernel(
     const int64_t hidden_dim,
     const int64_t num_tokens) {
   const int token_idx = blockIdx.x;
   if (token_idx >= num_tokens) return;

   const int tid = threadIdx.x;
...
@@ -25,9 +24,20 @@ __global__ void per_token_quant_fp8_kernel(
   float max_value = 0.0f;
-  for (int i = tid; i < hidden_dim; i += block_dim) {
-    float val = static_cast<float>(token_input[i]);
-    max_value = fmaxf(max_value, fabsf(val));
-  }
+  constexpr uint32_t vec_size = 16 / sizeof(T);
+  using vec_t = flashinfer::vec_t<T, vec_size>;
+  const int32_t num_vec_elems = hidden_dim / vec_size;
+
+  // Find max using vectorized loads
+  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
+    vec_t input_vec;
+    input_vec.cast_load(token_input + i * vec_size);
+
+#pragma unroll
+    for (uint32_t j = 0; j < vec_size; ++j) {
+      float val = static_cast<float>(input_vec[j]);
+      max_value = fmaxf(max_value, fabsf(val));
+    }
+  }

   max_value = blockReduceMax(max_value);
...
@@ -41,11 +51,7 @@ __global__ void per_token_quant_fp8_kernel(
   const float scale_val = 1.0f / block_max;

-  constexpr uint32_t vec_size = 16 / sizeof(T);
-  using vec_t = flashinfer::vec_t<T, vec_size>;
-  const int32_t num_vec_elems = hidden_dim / vec_size;
-
   // Quantize using vectorized loads
   for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
     vec_t input_vec;
     input_vec.cast_load(token_input + i * vec_size);
...
@@ -53,7 +59,7 @@ __global__ void per_token_quant_fp8_kernel(
     FP8_TYPE output_arr[vec_size];
 #pragma unroll
     for (uint32_t j = 0; j < vec_size; ++j) {
-      float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
 #ifndef USE_ROCM
       output_arr[j] = static_cast<FP8_TYPE>(val);
 #else
...
@@ -68,18 +74,6 @@ __global__ void per_token_quant_fp8_kernel(
       token_output[i * vec_size + j] = output_arr[j];
     }
   }
-
-  const int32_t remaining_start = num_vec_elems * vec_size;
-  for (int32_t idx = remaining_start + tid; idx < hidden_dim; idx += block_dim) {
-    float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(token_input[idx]) * scale_val, FP8_E4M3_MAX));
-#ifndef USE_ROCM
-    token_output[idx] = static_cast<FP8_TYPE>(val);
-#else
-    token_output[idx] = c10::Float8_e4m3fnuz(
-        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
-        c10::Float8_e4m3fnuz::from_bits());
-#endif
-  }
 }

 void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
...
@@ -91,7 +85,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
   const int64_t num_tokens = input_sizes[0];
   const int64_t hidden_dim = input_sizes[1];
-  const int block_size = 128;
+  TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
+
+  const int block_size = 256;

   const int num_blocks = num_tokens;
   dim3 grid(num_blocks);
...
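The 20-28% speedup quoted in the commit title comes from this kernel: the max-reduction pass now uses 16-byte vectorized loads, the scalar tail loop is removed (hidden_dim must be divisible by 8, enforced by the new TORCH_CHECK), and the block size grows from 128 to 256 threads. A rough timing harness for the per-token path on a local build (a sketch, not part of the commit; assumes a CUDA device with sgl_kernel installed):

import torch

from sglang.srt.custom_op import scaled_fp8_quant

x = torch.randn(4096, 7168, device="cuda", dtype=torch.bfloat16)  # hidden_dim % 8 == 0

# Warm up, then time the dynamic per-token path with CUDA events.
for _ in range(10):
    scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
end.record()
torch.cuda.synchronize()
print(f"per-token fp8 quant: {start.elapsed_time(end) / 100:.3f} ms/iter")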