change / sglang · Commits · 51d25405

Unverified commit 51d25405, authored Mar 04, 2025 by HAI, committed by GitHub on Mar 04, 2025.

ROCm: update aiter and its usage to fused moe (bfloat16, fp8, fp8 block-quant) (#4053)

Parent: e0a2c963

Showing 4 changed files with 82 additions and 40 deletions (+82, -40).
docker/Dockerfile.rocm                                    +1  -1
python/pyproject.toml                                     +1  -1
python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +16 -11
python/sglang/srt/layers/quantization/fp8.py             +64 -27
docker/Dockerfile.rocm

@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
 
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="dev/testx"
+ARG AITER_COMMIT="testx"
 
 RUN git clone ${SGL_REPO} \
     && cd sglang \
python/pyproject.toml

@@ -51,7 +51,7 @@ srt = [
 ]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
-# => base docker rocm/vllm-dev:20241022, not from public vllm whl
+# => base docker rocm/vllm-dev:20250114, not from public vllm whl
 srt_hip = ["sglang[runtime_common]", "sgl-kernel==0.0.3.post6", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
 
 # xpu is not enabled in public vllm and torch whl,
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -29,6 +29,9 @@ import logging
 
 is_hip_ = is_hip()
 
+if is_hip_:
+    from aiter import ck_moe
+
 logger = logging.getLogger(__name__)

@@ -173,18 +176,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         )
 
         if is_hip_ and get_bool_env_var("CK_MOE"):
-            import aiter
-            from aiter.fused_moe import fused_experts_ck
-
             assert activation == "silu", f"{activation=} is not supported."
             assert not no_combine, "unsupported"
-            return fused_experts_ck(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
+            return ck_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                None,
+                None,
+                None,
+                None,
+                32,
+                None,
+                activation,
             )
         else:
             return fused_experts(
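The net effect of this hunk: on ROCm the per-call `fused_experts_ck` import is replaced by the module-level `ck_moe` import, and the kernel is invoked positionally. The dispatch itself is a plain environment-variable gate. A minimal, self-contained sketch of that gate follows; the string return values are hypothetical stand-ins for the real kernel calls, and `get_bool_env_var` here only roughly mirrors sglang's helper of the same name:

import os

def get_bool_env_var(name: str) -> bool:
    # Rough equivalent of sglang's helper: a truthy string enables the flag.
    return os.environ.get(name, "false").lower() in ("true", "1")

def pick_moe_kernel(is_hip_: bool, activation: str = "silu") -> str:
    if is_hip_ and get_bool_env_var("CK_MOE"):
        # aiter's ck_moe path only handles silu, hence the assert in the diff.
        assert activation == "silu", f"{activation=} is not supported."
        return "aiter.ck_moe"
    return "triton fused_experts"

print(pick_moe_kernel(is_hip_=False))  # -> "triton fused_experts"

With this gate in place, setting `CK_MOE=1` routes every bf16 fused-MoE call through aiter, while leaving it unset keeps the Triton path, so the backend can be flipped per deployment without code changes.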
python/sglang/srt/layers/quantization/fp8.py
@@ -51,6 +51,10 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 is_hip_ = is_hip()
 
+if is_hip_:
+    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
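As in layer.py, the ROCm-only imports happen once at module import time rather than in the hot path. One consequence worth noting: on HIP builds, aiter becomes a hard import-time dependency even when `CK_MOE` is off. A defensive variant (illustrative only; not what this commit does) would tolerate a missing aiter wheel:

# Illustrative alternative (not the commit's code): degrade gracefully when
# the aiter wheel is absent instead of failing at import time on ROCm hosts.
try:
    from aiter.fused_moe_bf16_asm import asm_moe
    from aiter.ops.shuffle import shuffle_weight
    _AITER_AVAILABLE = True
except ImportError:
    asm_moe = shuffle_weight = None
    _AITER_AVAILABLE = False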
@@ -533,6 +537,20 @@ class Fp8MoEMethod:
             )
             layer.register_parameter("w13_weight_scale", w13_weight_scale)
             layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            if is_hip_ and get_bool_env_var("CK_MOE"):
+                # ROCm: use column scaling; duplicate the scaling numbers in the per-tensor case
+                w13_weight_scale1 = torch.nn.Parameter(
+                    torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                w2_weight_scale1 = torch.nn.Parameter(
+                    torch.ones(num_experts, hidden_size, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
+                layer.register_parameter("w2_weight_scale1", w2_weight_scale1)
+
             # Add the quantization method used (per tensor/grouped/channel)
             # to ensure the weight scales are loaded in properly
             extra_weight_attrs.update(
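The `*_weight_scale1` parameters give every output column its own FP32 scale: shape `(num_experts, 2 * intermediate_size)` for the fused gate/up projection and `(num_experts, hidden_size)` for the down projection. They start as ones so that a per-tensor checkpoint's single scale per expert can later be broadcast into every column (the `unsqueeze(-1)` hunks below). A shape-level sketch with toy sizes (all numbers hypothetical):

# Toy shapes only: num_experts=4, intermediate_size=8, hidden_size=16.
import torch

num_experts, intermediate_size, hidden_size = 4, 8, 16
w13_weight_scale1 = torch.ones(num_experts, 2 * intermediate_size)  # (4, 16)
w2_weight_scale1 = torch.ones(num_experts, hidden_size)             # (4, 16)

# A per-tensor checkpoint carries one scale per expert:
w13_weight_scale = torch.tensor([0.5, 1.0, 2.0, 4.0])               # (4,)

# Broadcasting duplicates that scalar across every column of its expert:
w13_weight_scale1 *= w13_weight_scale.unsqueeze(-1)                 # (4, 1) -> (4, 16)
assert torch.all(w13_weight_scale1[2] == 2.0)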
@@ -602,6 +620,15 @@ class Fp8MoEMethod:
                 w2_weight_scale, requires_grad=False
             )
             layer.w2_input_scale = None
+
+            if get_bool_env_var("CK_MOE"):
+                # Pre-shuffle weights
+                layer.w13_weight.data = shuffle_weight(
+                    layer.w13_weight.contiguous(), (16, 16)
+                )
+                layer.w2_weight.data = shuffle_weight(
+                    layer.w2_weight.contiguous(), (16, 16)
+                )
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
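`shuffle_weight(w, (16, 16))` reorders the expert weights once, at checkpoint-load time, into the tiled layout the aiter kernels expect, so the per-token hot path never pays for the permutation; `.contiguous()` is needed because the shuffle operates on the raw memory layout. The exact tiling is internal to aiter; the sketch below is a loudly hypothetical illustration of the general "pre-tile at load time" idea, not aiter's actual layout:

# Hypothetical illustration of load-time tiling; NOT aiter's real shuffle.
import torch

def toy_pretile(w: torch.Tensor, tile: int = 16) -> torch.Tensor:
    # Reorder a (rows, cols) matrix into row-major (tile x tile) blocks so a
    # kernel can stream one block with a single contiguous read.
    r, c = w.shape
    assert r % tile == 0 and c % tile == 0
    blocks = w.reshape(r // tile, tile, c // tile, tile).permute(0, 2, 1, 3)
    return blocks.contiguous()  # shape: (r//tile, c//tile, tile, tile)

w = torch.randn(64, 128)
w_tiled = toy_pretile(w)  # done once after load, like shuffle_weight above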
@@ -640,6 +667,9 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
+
+            # ROCm (CK_MOE): using column-wise scaling
+            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
 
         elif get_bool_env_var("MOE_PADDING"):
             # On ROCm, apply weight padding (to minimize memory-channel contention) only if set
             layer.w13_weight = torch.nn.Parameter(
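This fold turns the per-expert tensor scale into an equivalent column-wise scale, so the aiter kernel can always consume a column-wise scale regardless of how the checkpoint was quantized. A toy check (made-up numbers) that the two dequantization paths agree:

# Sanity sketch: column-wise dequant with a duplicated per-tensor scale
# equals plain per-tensor dequant.
import torch

q = torch.randint(-8, 8, (4, 6), dtype=torch.int8)   # stand-in for fp8 data
per_tensor = torch.tensor(0.25)                       # one scale per expert

col_scale = torch.ones(4, 6) * per_tensor             # what scale1 becomes
deq_col = q.float() * col_scale                       # column-wise path
deq_tensor = q.float() * per_tensor                   # per-tensor path
assert torch.equal(deq_col, deq_tensor)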
@@ -744,6 +774,9 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
+
+            # ROCm (CK_MOE): using column-wise scaling
+            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
 
         elif get_bool_env_var("MOE_PADDING"):
             # On ROCm, apply weight padding (to minimize memory-channel contention) only if set
             layer.w13_weight = torch.nn.Parameter(
@@ -790,34 +823,38 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if is_hip_ and get_bool_env_var("CK_MOE"):
-            import aiter
-            from aiter.fused_moe import fused_experts_ck
-
-            assert activation == "silu", f"{activation=} is not supported."
+        if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
+            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time being.
             assert not no_combine, f"{no_combine=} is not supported."
-            return fused_experts_ck(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                use_fp8_w8a8=True,
-                w1_scale=(
-                    layer.w13_weight_scale_inv
-                    if self.block_quant
-                    else layer.w13_weight_scale
-                ),
-                w2_scale=(
-                    layer.w2_weight_scale_inv
-                    if self.block_quant
-                    else layer.w2_weight_scale
-                ),
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-            )
+            if self.block_quant:
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale_inv,
+                    layer.w2_weight_scale_inv,
+                    None,
+                    None,
+                    False,
+                    None,
+                    block_shape=tuple(self.quant_config.weight_block_size),
+                    expert_mask=None,
+                )
+            else:
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale1,
+                    layer.w2_weight_scale1,
+                    None,
+                    None,
+                    False,
+                )
         else:
             # Expert fusion with FP8 quantization
             return fused_experts(
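The apply-path change replaces the earlier `fused_experts_ck` call with two `asm_moe` variants: under FP8 block quantization, the inverse scales (`*_weight_scale_inv`) and the checkpoint's `weight_block_size` are passed through as `block_shape`; otherwise the duplicated column-wise scales (`*_weight_scale1`) prepared above are used. Note also that the activation check moved from an assert into the `if` condition, so a non-silu activation now silently falls back to the Triton `fused_experts` path instead of raising. A condensed, runnable sketch of the decision tree (names mirror the diff; this is not sglang code):

# Control-flow sketch of the FP8 apply() dispatch after this commit; returns
# which kernel would run rather than calling it.
def pick_fp8_moe_kernel(is_hip_: bool, ck_moe_env: bool, activation: str,
                        block_quant: bool) -> str:
    if is_hip_ and ck_moe_env and activation == "silu":
        if block_quant:
            return "asm_moe (block-quant: *_weight_scale_inv + block_shape)"
        return "asm_moe (per-tensor: *_weight_scale1 column scales)"
    # CUDA, CK_MOE unset, or a non-silu activation all take the Triton path.
    return "fused_experts (Triton)"

assert pick_fp8_moe_kernel(True, True, "gelu", False) == "fused_experts (Triton)"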