sglang / Commits / d364b9b0

Unverified commit d364b9b0, authored Apr 28, 2025 by HAI; committed by GitHub on Apr 28, 2025
ROCm: update AITER (#5816)
Parent: 849c83a0
Showing 7 changed files with 48 additions and 52 deletions (+48 -52)
.github/workflows/pr-test-amd.yml                            +6  -6
3rdparty/amd/tuning/benchmark_moe_rocm.py                    +1  -1
docker/Dockerfile.rocm                                       +2  -2
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py   +2  -2
python/sglang/srt/layers/moe/fused_moe_triton/layer.py       +15 -17
python/sglang/srt/layers/quantization/fp8.py                 +20 -22
python/sglang/srt/layers/quantization/fp8_utils.py           +2  -2
.github/workflows/pr-test-amd.yml

@@ -38,12 +38,12 @@ jobs:
       else
         DEVICE_FLAG="--device /dev/dri"
       fi
-      docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+      docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
       docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
         -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
         --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
         -w /sglang-checkout --name ci_sglang \
-        lmsysorg/sglang:v0.4.5.post3-rocm630
+        ghcr.io/saienduri/sglang-aiter-v0.1.1:428

     - name: Install dependencies
       run: |

@@ -82,12 +82,12 @@ jobs:
       else
         DEVICE_FLAG="--device /dev/dri"
       fi
-      docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+      docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
       docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
         -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
         --cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
         -w /sglang-checkout --name ci_sglang \
-        lmsysorg/sglang:v0.4.5.post3-rocm630
+        ghcr.io/saienduri/sglang-aiter-v0.1.1:428

     - name: Install dependencies
       run: |

@@ -120,12 +120,12 @@ jobs:
       else
         DEVICE_FLAG="--device /dev/dri"
       fi
-      docker pull lmsysorg/sglang:v0.4.5.post3-rocm630
+      docker pull ghcr.io/saienduri/sglang-aiter-v0.1.1:428
       docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
         -v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
         --cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
         -w /sglang-checkout --name ci_sglang \
-        lmsysorg/sglang:v0.4.5.post3-rocm630
+        ghcr.io/saienduri/sglang-aiter-v0.1.1:428

     - name: Install dependencies
       run: |
3rdparty/amd/tuning/benchmark_moe_rocm.py

@@ -15,7 +15,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     get_config_file_name,
 )

-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0


 def main(model, tp_size, dtype: str, batches):
docker/Dockerfile.rocm

@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
-ARG AITER_COMMIT="testx"
+ARG AITER_COMMIT="v0.1.1"

 RUN git clone ${SGL_REPO} \
     && cd sglang \

@@ -74,7 +74,7 @@ ENV SGLANG_SET_CPU_AFFINITY=1
 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
 ENV NCCL_MIN_NCHANNELS=112
-ENV MOE_PADDING=1
+ENV SGLANG_MOE_PADDING=1
 ENV VLLM_FP8_PADDING=1
 ENV VLLM_FP8_ACT_PADDING=1
 ENV VLLM_FP8_WEIGHT_PADDING=1
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -45,7 +45,7 @@ if _is_cuda or _is_hip:
 logger = logging.getLogger(__name__)

-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )

@@ -1327,7 +1327,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (_is_hip and get_bool_env_var("CK_MOE"))
+        or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
     ):
         padded_size = 0
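The padding gate above is evaluated once at module import time, so the renamed variable has to be exported before sglang is imported. A minimal sketch of how the flag is expected to be driven after this commit; the environment-setup code itself is illustrative and not part of the diff:

import os

# Assumption: the flag must be set before any sglang import, because
# padding_size is computed when fused_moe.py is first imported.
os.environ.setdefault("SGLANG_MOE_PADDING", "1")

from sglang.srt.layers.moe.fused_moe_triton import fused_moe

# 128 when SGLANG_MOE_PADDING=1 (the default in docker/Dockerfile.rocm), else 0.
print(fused_moe.padding_size)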
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -30,7 +30,9 @@ import logging
 _is_hip = is_hip()

 if _is_hip:
-    from aiter import ck_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import ck_moe_2stages
+    from aiter.ops.shuffle import shuffle_weight

 logger = logging.getLogger(__name__)

@@ -102,14 +104,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_hip and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()

@@ -182,21 +184,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             routed_scaling_factor=routed_scaling_factor,
         )

-        if _is_hip and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             assert not no_combine, "unsupported"
-            return ck_moe(
+            return ck_moe_2stages(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                None,
-                None,
-                None,
-                32,
-                None,
-                activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         else:
             return fused_experts(

@@ -527,7 +525,7 @@ class FusedMoE(torch.nn.Module):
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
-            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                 loaded_weight = loaded_weight * 2.0

             # this is needed for compressed-tensors only

@@ -569,7 +567,7 @@ class FusedMoE(torch.nn.Module):
             quant_method = getattr(param, "quant_method", None)
             if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
-                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                     loaded_weight = loaded_weight * 0.5

                 self._load_per_channel_weight_scale(

@@ -592,7 +590,7 @@ class FusedMoE(torch.nn.Module):
                 )
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
-                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
                     loaded_weight = loaded_weight * 2.0

                 self._load_per_tensor_weight_scale(
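Taken together, the layer.py hunks swap the old single-call path (permute_weight at load time, ck_moe with positional placeholders at run time) for AITER's two-stage kernel: weights are pre-shuffled into a (16, 16) tile layout and the activation is passed explicitly. A condensed sketch with names copied from the diff; exact aiter signatures may differ between releases:

import torch
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight

def prepare_moe_weights(w13: torch.Tensor, w2: torch.Tensor):
    # Load-time step: replaces permute_weight(); (16, 16) is the tile layout
    # the asm kernels expect.
    return (
        torch.nn.Parameter(shuffle_weight(w13, (16, 16)), requires_grad=False),
        torch.nn.Parameter(shuffle_weight(w2, (16, 16)), requires_grad=False),
    )

def run_aiter_moe(x, w13, w2, topk_weights, topk_ids, activation="silu"):
    # Run-time step: the old placeholder arguments (None, ..., 32, None) are
    # gone; the activation is selected through an explicit keyword instead.
    return ck_moe_2stages(
        x, w13, w2, topk_weights, topk_ids,
        activation=ActivationType.Silu if activation == "silu" else ActivationType.Gelu,
    )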
python/sglang/srt/layers/quantization/fp8.py

@@ -72,8 +72,8 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()

 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

 if not _is_cuda:

@@ -484,7 +484,7 @@ class Fp8MoEMethod:
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("USE_INT4_WEIGHT")
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
         tp_size = get_tensor_model_parallel_world_size()

@@ -511,7 +511,7 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(

@@ -585,7 +585,7 @@ class Fp8MoEMethod:
         if (
             _is_hip
-        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),

@@ -612,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )

@@ -644,7 +644,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return

@@ -675,7 +675,7 @@ class Fp8MoEMethod:
             )
             layer.w2_input_scale = None

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(
                     layer.w13_weight.contiguous(), (16, 16)

@@ -798,17 +798,15 @@ class Fp8MoEMethod:
         return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )

@@ -847,23 +845,21 @@ class Fp8MoEMethod:
             padding_size,  # Avoid circular import
         )

-        if get_bool_env_var("CK_MOE"):
+        if get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                # permute_weight(layer.w13_weight.data),
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                # permute_weight(layer.w2_weight.data),
                 shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (CK_MOE): using column-wise scaling
+            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-        elif get_bool_env_var("MOE_PADDING"):
+        elif get_bool_env_var("SGLANG_MOE_PADDING"):
             # If ROCm, apply weight padding (min. Mem channel contention) only if set
             layer.w13_weight = torch.nn.Parameter(
                 F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),

@@ -912,15 +908,16 @@ class Fp8MoEMethod:
         )

         if _is_hip:
-            if get_bool_env_var("USE_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages_win4(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(

@@ -930,13 +927,13 @@ class Fp8MoEMethod:
                 ),
             )

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
-                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                     assert (
                         activation == "silu"
-                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                     return asm_moe(
                         x,
                         layer.w13_weight,

@@ -955,6 +952,7 @@ class Fp8MoEMethod:
                         layer.w2_weight,
                         topk_weights,
                         topk_ids,
+                        QuantType.per_Token,
                         layer.w13_weight_scale1,
                         layer.w2_weight_scale1,
                         activation=(
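In fp8.py the INT4-weight branch loses its dedicated entry point: ck_moe_2stages_win4 is removed, and both it and the FP8 asm_moe call now pass QuantType.per_Token explicitly. A hedged sketch of the unified call; the argument order follows the diff, and the activation block (elided in the hunk above) is reproduced by analogy with layer.py:

from aiter import ActivationType, QuantType
from aiter.fused_moe_bf16_asm import ck_moe_2stages

def int4_fp8_moe(x, w13, w2, topk_weights, topk_ids,
                 w13_scale1, w2_scale1, activation="silu"):
    # Previously ck_moe_2stages_win4(...); the quantization mode is now an
    # explicit argument rather than being implied by the function name.
    return ck_moe_2stages(
        x, w13, w2, topk_weights, topk_ids,
        QuantType.per_Token,
        w13_scale1, w2_scale1,
        activation=ActivationType.Silu if activation == "silu" else ActivationType.Gelu,
    )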
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -31,7 +31,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()

-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
     from aiter import gemm_a8w8_blockscale

 if _is_cuda:

@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
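Across the seven files, three switches are renamed and the old names are no longer read anywhere: CK_MOE becomes SGLANG_AITER_MOE, MOE_PADDING becomes SGLANG_MOE_PADDING, and USE_INT4_WEIGHT becomes SGLANG_INT4_WEIGHT. A small helper one could run to catch stale deployment configs; it is illustrative only and not part of the commit:

import os
import warnings

_RENAMED_FLAGS = {
    "CK_MOE": "SGLANG_AITER_MOE",
    "MOE_PADDING": "SGLANG_MOE_PADDING",
    "USE_INT4_WEIGHT": "SGLANG_INT4_WEIGHT",
}

def warn_on_stale_flags() -> None:
    # Warn when only the pre-rename variable is exported, since code at this
    # commit no longer honors it.
    for old, new in _RENAMED_FLAGS.items():
        if old in os.environ and new not in os.environ:
            warnings.warn(f"{old} is no longer read; export {new} instead.")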