Commit 53aed988 (unverified), project sglang
Authored Dec 26, 2024 by HandH1998; committed by GitHub, Dec 26, 2024

Refactor MoE (#2575)

Co-authored-by: zhyncs <me@zhyncs.com>

Parent: 8a56b431
Showing 9 changed files with 1012 additions and 49 deletions (+1012 -49):

  python/sglang/srt/configs/model_config.py                    +4    -1
  python/sglang/srt/layers/linear.py                           +20   -2
  python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py   +78   -8
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py       +6    -1
  python/sglang/srt/layers/quantization/fp8.py                 +159  -25
  python/sglang/srt/layers/quantization/fp8_kernel.py          +278  -0
  python/sglang/srt/layers/quantization/fp8_utils.py           +90   -1
  python/sglang/srt/models/deepseek_v2.py                      +36   -11
  python/sglang/test/test_block_fp8.py                         +341  -0
python/sglang/srt/configs/model_config.py

@@ -94,7 +94,10 @@ class ModelConfig:
             )
         # FIXME: temporary special judge for MLA architecture
-        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
+        if (
+            "DeepseekV2ForCausalLM" in self.hf_config.architectures
+            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
+        ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
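This widens the MLA detection so that DeepSeek-V3 checkpoints are configured the same way as DeepSeek-V2 ones. A minimal sketch of the resulting check, with an illustrative architectures list standing in for the real hf_config:

    # Sketch only: illustrates the widened architecture check above.
    architectures = ["DeepseekV3ForCausalLM"]  # hypothetical hf_config.architectures
    is_mla = (
        "DeepseekV2ForCausalLM" in architectures
        or "DeepseekV3ForCausalLM" in architectures
    )
    assert is_mla  # both V2 and V3 now get head_dim = 256 and AttentionArch.MLA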
python/sglang/srt/layers/linear.py

@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)

@@ -628,6 +629,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             assert loaded_shard_id < len(self.output_sizes)

             tp_size = get_tensor_model_parallel_world_size()
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            if isinstance(param, BlockQuantScaleParameter):
+                weight_block_size = self.quant_method.quant_config.weight_block_size
+                block_n, _ = weight_block_size[0], weight_block_size[1]
+                shard_offset = (
+                    (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
+                ) // tp_size
+                shard_size = (
+                    (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+                )
+            else:
+                shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+                shard_size = self.output_sizes[loaded_shard_id] // tp_size

@@ -795,6 +807,12 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
             shard_size = self._get_shard_size_mapping(loaded_shard_id)
+            if isinstance(param, BlockQuantScaleParameter):
+                weight_block_size = self.quant_method.quant_config.weight_block_size
+                block_n, _ = weight_block_size[0], weight_block_size[1]
+                shard_offset = (shard_offset + block_n - 1) // block_n
+                shard_size = (shard_size + block_n - 1) // block_n
+
             param.load_qkv_weight(
                 loaded_weight=loaded_weight,
                 num_heads=self.num_kv_head_replicas,
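The block-quant branches above convert row-based shard offsets into scale-tensor offsets: a BlockQuantScaleParameter holds one fp32 scale per block_n x block_k weight tile, so offsets and sizes are first rounded up to whole blocks of block_n output rows. A worked sketch with hypothetical sizes:

    # Sketch only: hypothetical sizes to illustrate the block-scale shard math above.
    output_sizes = [4096, 4096]   # merged gate/up projections (illustrative)
    block_n = 128                 # weight_block_size[0]
    tp_size = 2
    loaded_shard_id = 1

    # Weight shards are expressed in output rows:
    weight_shard_offset = sum(output_sizes[:loaded_shard_id]) // tp_size   # 2048
    weight_shard_size = output_sizes[loaded_shard_id] // tp_size           # 2048

    # Scale shards are expressed in blocks of block_n rows (ceil division):
    scale_shard_offset = (
        (sum(output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
    ) // tp_size                                                           # 16
    scale_shard_size = (
        (output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
    )                                                                      # 16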
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -6,7 +6,7 @@ import functools
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 import triton

@@ -14,6 +14,7 @@ import triton.language as tl
 from vllm import _custom_ops as ops

 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name

 logger = logging.getLogger(__name__)

@@ -48,8 +49,14 @@ def fused_moe_kernel(
     stride_bn,
     stride_cm,
     stride_cn,
+    stride_asm,
+    stride_ask,
     stride_bse,
+    stride_bsk,
     stride_bsn,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,

@@ -133,6 +140,13 @@ def fused_moe_kernel(
         b_scale = tl.load(b_scale_ptrs)

     if use_fp8_w8a8:
-        a_scale = tl.load(a_scale_ptr)
-        b_scale = tl.load(b_scale_ptr + off_experts)
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+            )
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)

@@ -165,6 +179,16 @@ def fused_moe_kernel(
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8:
-            accumulator = tl.dot(a, b, acc=accumulator)
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_SIZE_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(
+                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+                )
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)

@@ -178,6 +202,9 @@ def fused_moe_kernel(
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_fp8_w8a8:
-        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)

@@ -262,6 +289,7 @@ def invoke_fused_moe_kernel(
     compute_type: tl.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

@@ -269,8 +297,16 @@ def invoke_fused_moe_kernel(
     padded_size = 0
     if use_fp8_w8a8:
         padded_size = padding_size
-        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
+        if block_shape is None:
+            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
     elif use_int8_w8a16:
         assert B_scale is not None
     else:

@@ -309,8 +345,13 @@ def invoke_fused_moe_kernel(
         B.stride(1),
         C.stride(1),
         C.stride(2),
-        B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
-        B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
         compute_type=compute_type,

@@ -415,6 +456,7 @@ def try_get_optimal_moe_config(
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
+    block_shape: Optional[List[int]] = None,
 ):
     from sglang.srt.layers.moe.fused_moe_triton import get_config

@@ -433,6 +475,13 @@ def try_get_optimal_moe_config(
     else:
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
+    # TODO(HandH1998): Optimize the configs of block-wise quant.
+    # NOTE(HandH1998): For block-wise quant,
+    # BLOCK_K must be divisable by block_shape[1]
+    # BLOCK_N and BLOCK_M has no requirements
+    if block_shape is not None:
+        config["BLOCK_SIZE_N"] = block_shape[0]
+        config["BLOCK_SIZE_K"] = block_shape[1]
     return config

@@ -464,6 +513,7 @@ def inplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,

@@ -478,6 +528,7 @@ def inplace_fused_experts(
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )

@@ -493,6 +544,7 @@ def inplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     pass

@@ -517,6 +569,7 @@ def outplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,

@@ -531,6 +584,7 @@ def outplace_fused_experts(
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )

@@ -546,6 +600,7 @@ def outplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)

@@ -571,6 +626,7 @@ def fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     if inplace:
         torch.ops.sglang.inplace_fused_experts(

@@ -585,6 +641,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )
         return hidden_states
     else:

@@ -600,6 +657,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )

@@ -616,6 +674,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
     if not use_fp8_w8a8:

@@ -647,6 +706,7 @@ def fused_experts_impl(
         (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
+        block_shape=block_shape,
     )

     config = get_config_func(M)

@@ -719,6 +779,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )

         ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

@@ -740,6 +801,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )

         torch.sum(

@@ -768,6 +830,7 @@ def fused_moe(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of

@@ -795,6 +858,12 @@ def fused_moe(
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
         w2.
+    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a1.
+    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a2.
+    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+        quantization.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.

@@ -826,4 +895,5 @@ def fused_moe(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )
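With the new block_shape argument, the fp8 path in invoke_fused_moe_kernel switches from per-tensor activation quantization (ops.scaled_fp8_quant) to per_token_group_quant_fp8 with group size block_k, and the kernel multiplies each tile by the matching per-block scales. A sketch of a block-wise fp8 call, mirroring the new test file added later in this commit (shapes are illustrative and a CUDA device is assumed):

    # Sketch only: block-wise fp8 fused MoE, shapes chosen for illustration.
    import torch
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

    M, K, N, E, topk = 32, 256, 128, 8, 2
    block_shape = [128, 128]  # [block_n, block_k]
    a = torch.randn(M, K, dtype=torch.half, device="cuda")
    w1 = torch.randn(E, 2 * N, K, device="cuda").to(torch.float8_e4m3fn)
    w2 = torch.randn(E, K, N, device="cuda").to(torch.float8_e4m3fn)
    w1_s = torch.rand(E, (2 * N + 127) // 128, (K + 127) // 128, device="cuda")
    w2_s = torch.rand(E, (K + 127) // 128, (N + 127) // 128, device="cuda")
    score = torch.randn(M, E, dtype=torch.half, device="cuda")

    out = fused_moe(
        a, w1, w2, score, topk,
        renormalize=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=block_shape,  # activations are quantized per token group of block_k
    )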
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -34,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
     GROUP = "group"
+    BLOCK = "block"


 class FusedMoEMethodBase(QuantizeMethodBase):

@@ -214,6 +215,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.top_k = top_k
         self.num_experts = num_experts
+        assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize

@@ -470,7 +472,10 @@ class FusedMoE(torch.nn.Module):
                 expert_data=expert_data,
                 tp_rank=tp_rank,
             )
-        elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
+        elif quant_method in [
+            FusedMoeWeightScaleSupported.GROUP.value,
+            FusedMoeWeightScaleSupported.BLOCK.value,
+        ]:
             self._load_model_weight_or_group_weight_scale(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
python/sglang/srt/layers/quantization/fp8.py

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (

@@ -32,7 +33,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.fp8_utils import (
+    BlockQuantScaleParameter,
+    apply_w8a8_block_fp8_linear,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,

@@ -53,6 +58,7 @@ class Fp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
         ignored_layers: Optional[List[str]] = None,
+        weight_block_size: List[int] = None,
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:

@@ -64,6 +70,20 @@ class Fp8Config(QuantizationConfig):
             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
         self.ignored_layers = ignored_layers or []
+        if weight_block_size is not None:
+            if not is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    f"The block-wise quantization only supports fp8-serialized checkpoint for now."
+                )
+            if len(weight_block_size) != 2:
+                raise ValueError(
+                    f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
+                )
+            if activation_scheme != "dynamic":
+                raise ValueError(
+                    f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
+                )
+        self.weight_block_size = weight_block_size

     @classmethod
     def get_name(cls) -> str:

@@ -87,10 +107,12 @@ class Fp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized = "fp8" in quant_method
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
         ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
         return cls(
             is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
             activation_scheme=activation_scheme,
             ignored_layers=ignored_layers,
+            weight_block_size=weight_block_size,
         )

     def get_quant_method(

@@ -143,6 +165,11 @@ class Fp8LinearMethod(LinearMethodBase):
         if is_hip():
             self.use_marlin = False

+        self.block_quant = self.quant_config.weight_block_size is not None
+        if self.block_quant:
+            # Marlin doesn't support block-wise fp8
+            self.use_marlin = False
+
     def create_weights(
         self,
         layer: torch.nn.Module,

@@ -153,10 +180,35 @@ class Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")

+        tp_size = get_tensor_model_parallel_world_size()
+        if self.block_quant:
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # Required by row parallel
+            if tp_size > 1 and input_size // input_size_per_partition == tp_size:
+                if input_size_per_partition % block_k != 0:
+                    raise ValueError(
+                        f"Weight input_size_per_partition = "
+                        f"{input_size_per_partition} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )
+            # Required by collum parallel or enabling merged weights
+            if (
+                tp_size > 1 and output_size // output_size_per_partition == tp_size
+            ) or len(output_partition_sizes) > 1:
+                for output_partition_size in output_partition_sizes:
+                    if output_partition_size % block_n != 0:
+                        raise ValueError(
+                            f"Weight output_partition_size = "
+                            f"{output_partition_size} is not divisible by "
+                            f"weight quantization block_n = {block_n}."
+                        )
+
         layer.logical_widths = output_partition_sizes
         layer.input_size_per_partition = input_size_per_partition

@@ -184,11 +236,25 @@ class Fp8LinearMethod(LinearMethodBase):
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = PerTensorScaleParameter(
-                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
-                weight_loader=weight_loader,
-            )
-            scale[:] = torch.finfo(torch.float32).min
-            layer.register_parameter("weight_scale", scale)
+            if self.block_quant:
+                assert self.quant_config.activation_scheme == "dynamic"
+                scale = BlockQuantScaleParameter(
+                    data=torch.empty(
+                        (output_size_per_partition + block_n - 1) // block_n,
+                        (input_size_per_partition + block_k - 1) // block_k,
+                        dtype=torch.float32,
+                    ),
+                    input_dim=1,
+                    output_dim=0,
+                    weight_loader=weight_loader,
+                )
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("weight_scale_inv", scale)
+            else:
+                scale = PerTensorScaleParameter(
+                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                    weight_loader=weight_loader,
+                )
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("weight_scale", scale)

@@ -205,6 +271,9 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("input_scale", None)

     def process_weights_after_loading(self, layer: Module) -> None:
+        # Block quant doesn't need to process weights after loading
+        if self.block_quant:
+            return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:

@@ -295,6 +364,16 @@ class Fp8LinearMethod(LinearMethodBase):
                 bias=bias,
             )

+        if self.block_quant:
+            return apply_w8a8_block_fp8_linear(
+                input=x,
+                weight=layer.weight,
+                block_size=self.quant_config.weight_block_size,
+                weight_scale=layer.weight_scale_inv,
+                input_scale=layer.input_scale,
+                bias=bias,
+            )
+
         return apply_fp8_linear(
             input=x,
             weight=layer.weight,

@@ -339,6 +418,7 @@ class Fp8MoEMethod:
     def __init__(self, quant_config):
         self.quant_config = quant_config
+        self.block_quant = self.quant_config.weight_block_size is not None

     def create_weights(
         self,

@@ -353,6 +433,28 @@ class Fp8MoEMethod:
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn

+        tp_size = get_tensor_model_parallel_world_size()
+        if self.block_quant:
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+            # Required by collum parallel or enabling merged weights
+            if intermediate_size % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size} is not divisible by "
+                    f"weight quantization block_n = {block_n}."
+                )
+            if tp_size > 1:
+                # Required by row parallel
+                if intermediate_size % block_k != 0:
+                    raise ValueError(
+                        f"The input_size of down's weight = "
+                        f"{intermediate_size} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(

@@ -374,21 +476,45 @@ class Fp8MoEMethod:
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        if self.block_quant:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    (hidden_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    (hidden_size + block_n - 1) // block_n,
+                    (intermediate_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+            assert self.quant_config.activation_scheme == "dynamic"
+        else:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)

         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+            if self.block_quant
+            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         )
         # If loading fp8 checkpoint, pass the weight loaders.
         # If loading an fp16 checkpoint, do not (we will quantize in

@@ -422,7 +548,9 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
+        # Block quant doesn't need to process weights after loading
+        if self.block_quant:
+            return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)

@@ -519,7 +647,6 @@ class Fp8MoEMethod:
                 layer.w2_input_scale = torch.nn.Parameter(
                     w2_input_scale, requires_grad=False
                 )
-
         # Fp8 moe kernel needs single weight scale for w13 per expert.
         # We take the max then dequant and requant each expert.
         assert layer.w13_weight_scale is not None

@@ -594,10 +721,17 @@ class Fp8MoEMethod:
             topk_ids=topk_ids,
             inplace=True,
             use_fp8_w8a8=True,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            ),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
         )
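For block quantization the scale tensors above are allocated with one fp32 entry per weight tile, using ceil division of the partition sizes by the block shape. A quick worked example of the shape that ends up registered as weight_scale_inv (sizes are hypothetical):

    # Sketch only: scale-tensor shape implied by the block-quant branch above.
    output_size_per_partition = 4096
    input_size_per_partition = 7168
    block_n, block_k = 128, 128  # quant_config.weight_block_size

    scale_rows = (output_size_per_partition + block_n - 1) // block_n  # 32
    scale_cols = (input_size_per_partition + block_k - 1) // block_k   # 56
    # weight_scale_inv: (32, 56) fp32 scales, one per 128x128 tile of the fp8 weight.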
python/sglang/srt/layers/quantization/fp8_kernel.py  (new file, 0 → 100644)

from typing import List, Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Collums of input
    N,
    # Avoid to divide zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group quantization on a
    tensor.

    This function converts the tensor values into float8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tenosr with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum to avoid dividing zero.
        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_fp8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s


@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
    tensor `C`.
    """

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise quantization.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dytpe: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # TODO(HandH1998):
    # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
    # BLOCK_SIZE_K must be divisable by block_k
    # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements
    BLOCK_SIZE_M = 128
    if M < BLOCK_SIZE_M:
        BLOCK_SIZE_M = triton.next_power_of_2(M)
        BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
    BLOCK_SIZE_K = block_k
    assert block_k % BLOCK_SIZE_K == 0
    BLOCK_SIZE_N = block_n

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    _w8a8_block_fp8_matmul[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
    )

    return C
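The two public helpers in this new file can be exercised on their own; a usage sketch with illustrative shapes (a CUDA device and Triton are required, and block_size follows the [block_n, block_k] convention used above):

    # Sketch only: quantize activations per token group, then do a block-wise fp8 matmul.
    import torch
    from sglang.srt.layers.quantization.fp8_kernel import (
        per_token_group_quant_fp8,
        w8a8_block_fp8_matmul,
    )

    block_size = [128, 128]
    x = torch.randn(16, 512, dtype=torch.half, device="cuda")                       # activations
    w_q = torch.randn(1024, 512, device="cuda").to(torch.float8_e4m3fn)             # fp8 weight
    w_s = torch.rand(1024 // 128, 512 // 128, dtype=torch.float32, device="cuda")   # per-block scales

    x_q, x_s = per_token_group_quant_fp8(x, block_size[1])  # x_s has shape (16, 512 // 128)
    y = w8a8_block_fp8_matmul(x_q, w_q, x_s, w_s, block_size, output_dtype=x.dtype)  # (16, 1024)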
python/sglang/srt/layers/quantization/fp8_utils.py

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
+from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
+
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    w8a8_block_fp8_matmul,
+)


 def normalize_e4m3fn_to_e4m3fnuz(

@@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz(
     if input_scale is not None:
         input_scale = input_scale * 2.0
     return weight, weight_scale, input_scale
+
+
+def apply_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
+    output = w8a8_block_fp8_matmul(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+    )
+
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def input_to_float8(
+    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function quantizes input values to float8 values with tensor-wise quantization."""
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+def block_quant_to_tensor_quant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function converts block-wise quantization to tensor-wise quantization.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
+    Note only float8 is supported for now.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = x_q_block.to(torch.float32)
+
+    x_dq_block_tiles = [
+        [
+            x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            for i in range(k_tiles)
+        ]
+        for j in range(n_tiles)
+    ]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+
+    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    return x_q_tensor, scale
+
+
+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    block-wise quantization. Uses both column and row parallelism.
+    """
+
+    pass
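apply_w8a8_block_fp8_linear is what Fp8LinearMethod.apply (see fp8.py above) calls for block-quantized checkpoints: it flattens the input to 2D, quantizes it per token group of block_size[1], and runs w8a8_block_fp8_matmul against the block-wise weight scales. A usage sketch with placeholder tensors (CUDA assumed):

    # Sketch only: driving the block-wise linear helper directly.
    import torch
    from sglang.srt.layers.quantization.fp8_utils import apply_w8a8_block_fp8_linear

    block_size = [128, 128]
    x = torch.randn(4, 8, 512, dtype=torch.bfloat16, device="cuda")
    weight = torch.randn(1024, 512, device="cuda").to(torch.float8_e4m3fn)   # stands in for layer.weight
    weight_scale = torch.rand(8, 4, dtype=torch.float32, device="cuda")      # stands in for layer.weight_scale_inv

    y = apply_w8a8_block_fp8_linear(
        input=x,            # any leading dims; flattened to 2D internally
        weight=weight,
        block_size=block_size,
        weight_scale=weight_scale,
        input_scale=None,   # dynamic per-token-group quantization happens inside
        bias=None,
    )
    # y has shape (4, 8, 1024) and dtype bfloat16.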
python/sglang/srt/models/deepseek_v2.py

@@ -43,6 +43,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_to_tensor_quant,
+    input_to_float8,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,

@@ -186,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0


-def input_to_float8(x, dtype=torch.float8_e4m3fn):
-    finfo = torch.finfo(dtype)
-    min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
 class DeepseekV2Attention(nn.Module):
     def __init__(

@@ -869,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
+            # TODO(HandH1998): Modify it when nextn is supported.
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                if num_nextn_layers > 0 and name.startswith("model.layers"):
+                    name_list = name.split(".")
+                    if (
+                        len(name_list) >= 3
+                        and int(name_list[2]) >= self.config.num_hidden_layers
+                    ):
+                        continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:

@@ -933,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module):
                 ).T
             else:
                 w = self_attn.kv_b_proj.weight
+            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+            # This may affect the accuracy of fp8 model.
+            if (
+                hasattr(self.quant_config, "weight_block_size")
+                and w.dtype == torch.float8_e4m3fn
+            ):
+                weight_block_size = self.quant_config.weight_block_size
+                if weight_block_size is not None:
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    w, scale = block_quant_to_tensor_quant(
+                        w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                    )
+                    self_attn.w_scale = scale
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
             self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-            if hasattr(self_attn.kv_b_proj, "weight_scale"):
+            if (
+                hasattr(self_attn.kv_b_proj, "weight_scale")
+                and self_attn.w_scale is None
+            ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale


-EntryClass = DeepseekV2ForCausalLM
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
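The NOTE above explains why kv_b_proj gets requantized at load time: bmm_fp8 only accepts a per-tensor scale, so a block-quantized kv_b_proj is dequantized per tile and requantized with a single scale via block_quant_to_tensor_quant, at some possible cost in accuracy. A standalone sketch of that step on a hypothetical block-quantized weight:

    # Sketch only: block-wise -> per-tensor requantization as used for kv_b_proj above.
    import torch
    from sglang.srt.layers.quantization.fp8_utils import block_quant_to_tensor_quant

    weight_block_size = [128, 128]
    w_block = torch.randn(512, 256).to(torch.float8_e4m3fn)                  # block-quantized weight
    w_block_scale = torch.rand(512 // 128, 256 // 128, dtype=torch.float32)  # weight_scale_inv

    w_tensor, w_scale = block_quant_to_tensor_quant(w_block, w_block_scale, weight_block_size)
    # w_tensor is fp8 with a single per-tensor scale w_scale, suitable for bmm_fp8.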
python/sglang/test/test_block_fp8.py  (new file, 0 → 100644)

import itertools
import unittest

import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_fp8,
    w8a8_block_fp8_matmul,
)


# For test
def native_per_token_group_quant_fp8(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn):
    """Function to perform per-token-group quantization on an input tensor `x` using native torch.

    It converts the tensor values into float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    Note that only `torch.float8_e4m3fn` is supported for now.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / fp8_max
    x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))

    return x_q, x_s


class TestPerTokenGroupQuantFP8(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16, torch.float32]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    GROUP_SIZE = [64, 128, 256, 512]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _per_token_group_quant_fp8(self, num_tokens, d, dtype, group_size, seed):
        torch.manual_seed(seed)

        x = torch.rand(num_tokens, d, dtype=dtype)

        with torch.inference_mode():
            ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
            out, scale = per_token_group_quant_fp8(x, group_size)

        self.assertTrue(
            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
        )
        self.assertTrue(torch.allclose(scale, ref_scale))

    def test_per_token_group_quant_fp8(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.D,
            self.DTYPES,
            self.GROUP_SIZE,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                d=params[1],
                dtype=params[2],
                group_size=params[3],
                seed=params[4],
            ):
                self._per_token_group_quant_fp8(*params)


# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    """

    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N,)
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
    B_tiles = [
        [
            B[
                j * block_n : min((j + 1) * block_n, N),
                i * block_k : min((i + 1) * block_k, K),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]
    C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
    As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs[j][i]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


class TestW8A8BlockFP8Matmul(unittest.TestCase):
    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
    M = [1, 7, 83, 512, 2048]
    N = [128, 512, 1024, 4096, 7748, 13824]
    K = [256, 4096, 5120, 3884, 13824]
    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
        torch.manual_seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
        A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
        B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        block_n, block_k = block_size[0], block_size[1]
        n_tiles = (N + block_n - 1) // block_n
        k_tiles = (K + block_k - 1) // block_k

        As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
        Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale

        with torch.inference_mode():
            ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
            out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.001
        )

    def test_w8a8_block_fp8_matmul(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.BLOCK_SIZE,
            self.OUT_DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                block_size=params[3],
                out_dtype=params[4],
                seed=params[5],
            ):
                self._w8a8_block_fp8_matmul(*params)


# For test
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using native torch."""
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    # NOTE(HandH1998): Since "index_cuda" not implemented for 'Float8_e4m3fn', we need to cast `float8`` to `float32``.
    a_q = a_q.to(torch.float32)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_fp8_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
            )
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
            act_out = act_out.to(torch.float32)
            out[mask] = native_w8a8_block_fp8_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


class TestW8A8BlockFP8FusedMoE(unittest.TestCase):
    DTYPES = [torch.float32, torch.half, torch.bfloat16]
    M = [1, 33, 64, 222, 1024 * 128]
    N = [128, 1024, 2048]
    K = [256, 4096, 5120]
    E = [8, 24]
    TOP_KS = [2, 6]
    BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
    # BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
        torch.manual_seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a = torch.randn((M, K), dtype=dtype) / 10

        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
        w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * fp8_max
        w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k

        w1_s = torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
        w2_s = torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale

        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            out = fused_moe(
                a,
                w1,
                w2,
                score,
                topk,
                renormalize=False,
                use_fp8_w8a8=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=block_size,
            )
            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size)

        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.02
        )

    def test_w8a8_block_fp8_fused_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
            ):
                self._w8a8_block_fp8_fused_moe(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
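Since the module ends with unittest.main(), the new block-wise fp8 tests can be run directly on a CUDA machine, for example:

    python python/sglang/test/test_block_fp8.py

The setUpClass guards above skip all three test classes when CUDA is not available.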