sglang / Commits / d1da58e2

Unverified commit d1da58e2, authored Mar 11, 2025 by Yineng Zhang, committed by GitHub on Mar 11, 2025.

    unify is_cuda and is_hip (#4321)

Parent: 1cf63485

Showing 18 changed files with 104 additions and 92 deletions (+104 -92).
Files changed:

  python/sglang/srt/custom_op.py                                               +5   -3
  python/sglang/srt/distributed/device_communicators/custom_all_reduce.py      +18  -17
  python/sglang/srt/layers/attention/triton_ops/decode_attention.py            +6   -6
  python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py   +3   -3
  python/sglang/srt/layers/attention/triton_ops/extend_attention.py            +4   -4
  python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py        +3   -3
  python/sglang/srt/layers/moe/ep_moe/kernels.py                               +2   -1
  python/sglang/srt/layers/moe/ep_moe/layer.py                                 +3   -1
  python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py                   +9   -9
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py                       +7   -7
  python/sglang/srt/layers/quantization/fp8.py                                 +13  -13
  python/sglang/srt/layers/quantization/fp8_kernel.py                          +7   -7
  python/sglang/srt/layers/quantization/fp8_utils.py                           +4   -4
  python/sglang/srt/layers/quantization/w8a8_fp8.py                            +3   -1
  python/sglang/srt/model_executor/cuda_graph_runner.py                        +2   -2
  python/sglang/srt/models/deepseek_nextn.py                                   +3   -3
  python/sglang/srt/models/deepseek_v2.py                                      +4   -4
  python/sglang/srt/utils.py                                                   +8   -4
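Every file in this commit applies the same mechanical change: module-level flags previously spelled is_hip_ (or computed ad hoc as torch.cuda.is_available() and torch.version.cuda) become _is_hip / _is_cuda, computed once from the shared helpers in sglang.srt.utils. The sketch below is an illustrative condensation of that pattern, not an excerpt from the repository; the stand-in helper bodies mirror the utils.py hunk at the end of this diff.

    import torch

    # In sglang these come from sglang.srt.utils; they are defined inline here
    # so the sketch runs without the package installed.
    def is_cuda() -> bool:
        return bool(torch.cuda.is_available() and torch.version.cuda)

    def is_hip() -> bool:
        return torch.version.hip is not None

    # The unified convention: evaluate once at import time and cache the result
    # in underscore-prefixed module-level flags.
    _is_cuda = is_cuda()
    _is_hip = is_hip()

    def pick_backend() -> str:
        # Call sites branch on the cached flags instead of re-calling the helpers.
        if _is_cuda:
            return "cuda"
        if _is_hip:
            return "hip"
        return "native"

    print(pick_backend())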
python/sglang/srt/custom_op.py

 import torch
 from torch import nn

-_is_cuda = torch.cuda.is_available() and torch.version.cuda
-_is_rocm = torch.cuda.is_available() and torch.version.hip
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()

 class CustomOp(nn.Module):
...
@@ -34,7 +36,7 @@ class CustomOp(nn.Module):
     def dispatch_forward(self):
         if _is_cuda:
             return self.forward_cuda
-        elif _is_rocm:
+        elif _is_hip:
             return self.forward_hip
         else:
             return self.forward_native
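For context, dispatch_forward is the piece of CustomOp that consumes these flags: it returns one of forward_cuda, forward_hip, or forward_native, and the rest of the class calls whichever bound method was selected. A minimal, self-contained sketch of that shape follows; the __init__ wiring and the trivial forward bodies are assumptions made for illustration, not code from this file.

    import torch
    from torch import nn

    # Stand-ins for the module-level flags the diff imports from sglang.srt.utils.
    _is_cuda = bool(torch.cuda.is_available() and torch.version.cuda)
    _is_hip = torch.version.hip is not None

    class TinyCustomOp(nn.Module):
        """Illustrative subset of the CustomOp dispatch pattern shown above."""

        def __init__(self):
            super().__init__()
            # Assumed wiring: resolve the platform-specific method once up front.
            self._forward_method = self.dispatch_forward()

        def forward(self, *args, **kwargs):
            return self._forward_method(*args, **kwargs)

        def dispatch_forward(self):
            if _is_cuda:
                return self.forward_cuda
            elif _is_hip:
                return self.forward_hip
            else:
                return self.forward_native

        # A real subclass would override these three; here they share one body.
        def forward_native(self, x):
            return x

        forward_cuda = forward_native
        forward_hip = forward_native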
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

...
@@ -22,15 +22,16 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
 logger = logging.getLogger(__name__)

-is_hip_ = is_hip()
+_is_cuda = is_cuda()
+_is_hip = is_hip()

-if is_cuda():
+if _is_cuda:
     try:
         import pynvml
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)

-if is_hip_:
+if _is_hip:
     try:
         from amdsmi import (
             AmdSmiException,
...
@@ -43,7 +44,7 @@ if is_hip_:
         logger.warning("Failed to import amdsmi with %r", e)

 try:
-    if ops.use_vllm_custom_allreduce and not is_hip_:
+    if ops.use_vllm_custom_allreduce and not _is_hip:
         # Use vLLM custom allreduce
         ops.meta_size()
     else:
...
@@ -63,7 +64,7 @@ _R = TypeVar("_R")
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        if is_hip_:
+        if _is_hip:
             try:
                 amdsmi_init()
                 return fn(*args, **kwargs)
...
@@ -81,7 +82,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 @with_nvml_context
 def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
-    if is_hip_:
+    if _is_hip:
         """
         query if the set of gpus are fully connected by xgmi (1 hop)
         """
...
@@ -145,7 +146,7 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
     _MAX_CAR_SIZE = 8192 * 1024
-    if is_hip_:
+    if _is_hip:
         # crossover is at 16MB buffer size for ROCm
         _MAX_CAR_SIZE = 2 * 8192 * 1024
...
@@ -229,7 +230,7 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda() or is_hip_:
+        if _is_cuda or _is_hip:
             full_nvlink = is_full_nvlink(physical_device_ids, world_size)
         if world_size > 2 and not full_nvlink:
...
@@ -243,7 +244,7 @@ class CustomAllreduce:
         # this is expensive to compute at the first time
         # then we cache the result
         # On AMD GPU, p2p is always enabled between XGMI connected GPUs
-        if not is_hip_ and not _can_p2p(rank, world_size):
+        if not _is_hip and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "
...
@@ -256,7 +257,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        if ops.use_vllm_custom_allreduce and not is_hip_:
+        if ops.use_vllm_custom_allreduce and not _is_hip:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
...
@@ -279,7 +280,7 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            if is_hip_:
+            if _is_hip:
                 # meta data buffers need to be "uncached" for signal on MI200
                 self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
                 self.buffer = torch.empty(
...
@@ -418,7 +419,7 @@ class CustomAllreduce:
         ops.register_buffer(self._ptr, inp, handles, offsets)

     def register_graph_buffers(self):
-        if is_hip_:
+        if _is_hip:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
             handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
             logger.info("Registering %d cuda graph addresses", len(offset))
...
@@ -454,12 +455,12 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not is_hip_:
+        if ops.use_vllm_custom_allreduce and not _is_hip:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False
-        if is_hip_:
+        if _is_hip:
             if self.full_nvlink:
                 if self.world_size == 8:
                     if self.MSCCL:
...
@@ -532,7 +533,7 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if is_hip_:
+                if _is_hip:
                     return self.all_reduce_reg(input)
                 else:
                     return self.all_reduce(input, registered=True)
...
@@ -541,7 +542,7 @@ class CustomAllreduce:
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            if is_hip_:
+            if _is_hip:
                 # note: outside of cuda graph context,
                 # custom allreduce incurs a cost of cudaMemcpy, which should
                 # be small(<=1% of overall latency) compared to the performance
...
@@ -556,7 +557,7 @@ class CustomAllreduce:
         if ops.use_vllm_custom_allreduce:
             self.free_shared_buffer(self.meta_ptrs)
             self.free_shared_buffer(self.buffer_ptrs)
-        elif is_cuda():
+        elif _is_cuda:
             self.free_shared_buffer(self.buffer_ptrs)
             self.free_shared_buffer(self.tmp_result_buffer_ptrs)
             self.free_shared_buffer(self.barrier_in_ptrs)
...
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

...
@@ -27,7 +27,7 @@ import triton.language as tl
 from sglang.srt.utils import is_hip

-is_hip_ = is_hip()
+_is_hip = is_hip()

 logger = logging.getLogger(__name__)
...
@@ -180,7 +180,7 @@ def _decode_att_m_fwd(
 ):
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx
-    if is_hip_:
+    if _is_hip:
         BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
...
@@ -195,7 +195,7 @@ def _decode_att_m_fwd(
         num_warps = 4
     else:
         num_warps = 2
-        if is_hip_:
+        if _is_hip:
             num_warps = 1

     BLOCK_DMODEL = triton.next_power_of_2(Lk)
...
@@ -406,7 +406,7 @@ def _decode_grouped_att_m_fwd(
     Lv = v_buffer.shape[-1]

     # [TODO] work around shmem limit on MI3xx
-    if is_hip_ and Lk >= 576:
+    if _is_hip and Lk >= 576:
         BLOCK = 16

     if Lk == 576:
...
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
     extra_kargs = {}
     num_stages = 2
-    if is_hip_:
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
...
@@ -546,7 +546,7 @@ def _decode_softmax_reducev_fwd(
     NUM_KV_SPLITS = num_kv_splits

     extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
...
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

...
@@ -9,7 +9,7 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

-is_hip_ = is_hip()
+_is_hip = is_hip()

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
...
@@ -1032,7 +1032,7 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)

-    if is_hip_:
+    if _is_hip:
         BLOCK_M, BLOCK_N = (64, 64)
         num_warps = 4
...
@@ -1062,7 +1062,7 @@ def extend_attention_fwd(
         num_stages = 1

     extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

     _fwd_kernel[grid](
...
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

...
@@ -29,7 +29,7 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

-is_hip_ = is_hip()
+_is_hip = is_hip()

 @triton.jit
...
@@ -330,7 +330,7 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)

-    if is_hip_:
+    if _is_hip:
         BLOCK_M, BLOCK_N = (64, 64)
         num_warps = 4
...
@@ -364,7 +364,7 @@ def extend_attention_fwd(
         num_stages = 1

     extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}

     _fwd_kernel[grid](
...
@@ -403,7 +403,7 @@ def extend_attention_fwd(
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
-        STORE_TRANSPOSE=is_hip_,
+        STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
...
python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py

...
@@ -32,7 +32,7 @@ def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"

-is_hip_ = is_hip()
+_is_hip = is_hip()

 @triton.jit
...
@@ -333,7 +333,7 @@ def _decode_grouped_att_m_fwd_rope(
     BLOCK = 32
     # # [TODO] work around shmem limit on MI3xx
-    # if is_hip_ and kv_lora_rank >= 576:
+    # if _is_hip and kv_lora_rank >= 576:
     #     BLOCK = 16

     qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
...
@@ -353,7 +353,7 @@ def _decode_grouped_att_m_fwd_rope(
     extra_kargs = {}
     num_stages = 2
-    if is_hip_:
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
...
python/sglang/srt/layers/moe/ep_moe/kernels.py

...
@@ -6,8 +6,9 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.utils import is_cuda

-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()

 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
...
python/sglang/srt/layers/moe/ep_moe/layer.py

...
@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs
 logger = logging.getLogger(__name__)

+_is_hip = is_hip()
+

 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
...
@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
...
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

...
@@ -23,10 +23,11 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cuda,
     is_hip,
 )

-is_hip_ = is_hip()
+_is_hip = is_hip()

 logger = logging.getLogger(__name__)
...
@@ -36,8 +37,7 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )

-_is_cuda = torch.cuda.is_available() and torch.version.cuda
-_is_rocm = torch.cuda.is_available() and torch.version.hip
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
...
@@ -46,7 +46,7 @@ if _is_cuda:
         sglang_per_token_group_quant_fp8,
     )

-if _is_cuda or _is_rocm:
+if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
...
@@ -679,7 +679,7 @@ def get_default_config(
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 2 if is_hip_ else 4,
+            "num_stages": 2 if _is_hip else 4,
         }
         if M <= E:
             config = {
...
@@ -688,7 +688,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 2 if is_hip_ else 4,
+                "num_stages": 2 if _is_hip else 4,
             }
     else:
         # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
...
@@ -698,7 +698,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 2 if is_hip_ else 3,
+                "num_stages": 2 if _is_hip else 3,
             }
         else:
             config = {
...
@@ -976,7 +976,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (is_hip_ and get_bool_env_var("CK_MOE"))
+        or (_is_hip and get_bool_env_var("CK_MOE"))
     ):
         padded_size = 0
...
@@ -1131,7 +1131,7 @@ def fused_experts_impl(
             if no_combine:
                 pass
-            elif is_hip_:
+            elif _is_hip:
                 ops.moe_sum(
                     intermediate_cache3.view(*intermediate_cache3.shape),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
...
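The get_default_config hunks above only swap which flag gates the HIP-specific pipeline depth; the config dicts themselves are unchanged. A condensed sketch of that selection, with values copied from the hunks and a simplified, hypothetical function shape:

    import torch

    # Matches the flag the diff caches at module level via sglang.srt.utils.is_hip().
    _is_hip = torch.version.hip is not None

    def default_moe_config(block_size_k: int = 128, group_size_m: int = 32, num_warps: int = 8) -> dict:
        # In these defaults, HIP uses fewer Triton pipeline stages (2) than CUDA (4).
        return {
            "BLOCK_SIZE_K": block_size_k,
            "GROUP_SIZE_M": group_size_m,
            "num_warps": num_warps,
            "num_stages": 2 if _is_hip else 4,
        }

    print(default_moe_config())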
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

...
@@ -27,9 +27,9 @@ else:
 import logging

-is_hip_ = is_hip()
+_is_hip = is_hip()

-if is_hip_:
+if _is_hip:
     from aiter import ck_moe

 logger = logging.getLogger(__name__)
...
@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
                 permute_weight(layer.w13_weight.data),
                 requires_grad=False,
...
@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )

-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
             assert not no_combine, "unsupported"
             return ck_moe(
                 x,
...
@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
-            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                 loaded_weight = loaded_weight * 2.0

             # this is needed for compressed-tensors only
...
@@ -556,7 +556,7 @@ class FusedMoE(torch.nn.Module):
             quant_method = getattr(param, "quant_method", None)
             if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
-                if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                     loaded_weight = loaded_weight * 0.5
                 self._load_per_channel_weight_scale(
...
@@ -579,7 +579,7 @@ class FusedMoE(torch.nn.Module):
                 )
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
-                if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                     loaded_weight = loaded_weight * 2.0
                 self._load_per_tensor_weight_scale(
...
python/sglang/srt/layers/quantization/fp8.py

...
@@ -54,9 +54,9 @@ from sglang.srt.utils import (
 ACTIVATION_SCHEMES = ["static", "dynamic"]

-is_hip_ = is_hip()
+_is_hip = is_hip()

-if is_hip_:
+if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe
     from aiter.ops.shuffle import shuffle_weight
...
@@ -175,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip_:
+        if _is_hip:
             self.use_marlin = False

         self.block_quant = self.quant_config.weight_block_size is not None
...
@@ -287,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
...
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
...
@@ -563,7 +563,7 @@ class Fp8MoEMethod:
             layer.register_parameter("w2_weight_scale", w2_weight_scale)

-            if (is_hip_):
+            if (_is_hip):
                 # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
                 # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
                 w13_weight_scale1 = torch.nn.Parameter(
...
@@ -630,7 +630,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
...
@@ -667,7 +667,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
...
@@ -689,7 +689,7 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            if is_hip_:
+            if _is_hip:
                 self.process_weights_hip_scale_padding(layer)
             return
...
@@ -721,7 +721,7 @@ class Fp8MoEMethod:
             )

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
...
@@ -771,7 +771,7 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )

-            if is_hip_:
+            if _is_hip:
                 self.process_weights_hip_scale_padding(layer)
             return
...
@@ -882,7 +882,7 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )

-        if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
             return asm_moe(
...
@@ -895,7 +895,7 @@ class Fp8MoEMethod:
                 layer.w2_weight_scale1,
                 activation=activation,
             )
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
             # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
             assert (
                 activation == "silu"
...
python/sglang/srt/layers/quantization/fp8_kernel.py

...
@@ -22,12 +22,12 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
+from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip

-is_hip_ = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+_is_hip = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()

 if _is_cuda:
     import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
...
@@ -157,7 +157,7 @@ def per_token_group_quant_fp8(
     finfo = torch.finfo(dtype)
     fp8_max = finfo.max

-    if is_hip_:
+    if _is_hip:
         fp8_max = 224.0

     fp8_min = -fp8_max
...
@@ -332,7 +332,7 @@ def static_quant_fp8(
     finfo = torch.finfo(dtype)
     fp8_max = finfo.max

-    if is_hip_:
+    if _is_hip:
         fp8_max = 224.0

     fp8_min = -fp8_max
...
@@ -732,7 +732,7 @@ def w8a8_block_fp8_matmul(
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
-            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            if (_is_hip == True and num_workgroups <= get_device_core_count())
             else _w8a8_block_fp8_matmul
         )
...
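The quantization files repeat one platform-dependent choice: the FP8 storage dtype and the maximum value used for clamping. A condensed, illustrative sketch of that recurring logic (constants taken from the hunks above; requires a PyTorch build that exposes both float8 dtypes):

    import torch

    _is_hip = torch.version.hip is not None  # stand-in for the cached sglang.srt.utils.is_hip() flag

    # ROCm uses the e4m3fnuz variant; CUDA uses e4m3fn.
    fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

    fp8_max = torch.finfo(fp8_dtype).max
    if _is_hip:
        # The HIP path caps the usable range at 224.0, matching the hunks above.
        fp8_max = 224.0
    fp8_min = -fp8_max

    print(fp8_dtype, fp8_min, fp8_max)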
python/sglang/srt/layers/quantization/fp8_utils.py

...
@@ -17,8 +17,8 @@ from sglang.srt.utils import (
 use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

-is_hip_ = is_hip()
-if is_hip_ and get_bool_env_var("CK_MOE"):
+_is_hip = is_hip()
+if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale

 _is_cuda = is_cuda()
...
@@ -111,7 +111,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif is_hip_ and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("CK_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
...
@@ -142,7 +142,7 @@ def input_to_float8(
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
     fp8_max = finfo.max
-    if is_hip_:
+    if _is_hip:
         fp8_max = 224.0
     scale = fp8_max / amax
     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
...
python/sglang/srt/layers/quantization/w8a8_fp8.py

...
@@ -16,6 +16,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
 )
 from sglang.srt.utils import is_hip

+_is_hip = is_hip()
+

 class W8A8Fp8Config(QuantizationConfig):
     """Config class for W8A8 FP8 Quantization.
...
@@ -71,7 +73,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight = layer.weight
         weight_scale = layer.weight_scale.detach()
-        if is_hip():
+        if _is_hip:
             weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                 weight=weight, weight_scale=weight_scale
             )
...
python/sglang/srt/model_executor/cuda_graph_runner.py

...
@@ -35,7 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.utils import is_hip

-is_hip_ = is_hip()
+_is_hip = is_hip()

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
...
@@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = list(range(1, 33))

-    if is_hip_:
+    if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]

     if max(capture_bs) > model_runner.req_to_token_pool.size:
...
python/sglang/srt/models/deepseek_nextn.py

...
@@ -40,7 +40,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
 from sglang.srt.utils import add_prefix, is_hip

-is_hip_ = is_hip()
+_is_hip = is_hip()

 class DeepseekModelNextN(nn.Module):
...
@@ -277,7 +277,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if is_hip_:
+                    if _is_hip:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                             weight=w,
                             weight_scale=self_attn.kv_b_proj.weight_scale_inv,
...
@@ -301,7 +301,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if is_hip_:
+                if _is_hip:
                     self_attn.w_scale *= 2.0
...
python/sglang/srt/models/deepseek_v2.py

...
@@ -65,7 +65,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda_available, is_hip

-is_hip_ = is_hip()
+_is_hip = is_hip()

 if is_cuda_available():
     from sgl_kernel import bmm_fp8
...
@@ -571,7 +571,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         if no_absorb():
             return self.forward_normal(positions, hidden_states, forward_batch)
         else:
-            if is_hip_:
+            if _is_hip:
                 if (
                     os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
                     and forward_batch.forward_mode.is_decode()
...
@@ -1190,7 +1190,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if is_hip_:
+                    if _is_hip:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                             weight=w,
                             weight_scale=self_attn.kv_b_proj.weight_scale_inv,
...
@@ -1230,7 +1230,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if is_hip_:
+                if _is_hip:
                     self_attn.w_scale *= 2.0

     def get_embed_and_head(self):
...
python/sglang/srt/utils.py

...
@@ -72,13 +72,17 @@ show_time_cost = False
 time_infos = {}

 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
     """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None

+def is_rocm() -> bool:
+    return torch.cuda.is_available() and torch.version.hip
+

 def is_cuda():
-    return hasattr(torch, "cuda") and torch.version.cuda is not None
+    return torch.cuda.is_available() and torch.version.cuda


 def is_cuda_alike():
...
@@ -100,11 +104,11 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()


 def is_cuda_available():
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()


 def enable_show_time_cost():
...
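As reconstructed above, the helpers distinguish between a ROCm/CUDA build of PyTorch and an actually usable device: is_hip() only inspects torch.version.hip, while is_rocm() and the rewritten is_cuda() also require torch.cuda.is_available(). A small standalone script under that reading of the hunk:

    import torch

    def describe_platform() -> str:
        # Mirrors the helper semantics shown in the utils.py hunk: build flag vs. usable device.
        hip_build = torch.version.hip is not None
        cuda_usable = bool(torch.cuda.is_available() and torch.version.cuda)
        rocm_usable = bool(torch.cuda.is_available() and torch.version.hip)
        if cuda_usable:
            return "CUDA build with a visible GPU"
        if rocm_usable:
            return "ROCm build with a visible GPU"
        if hip_build:
            return "ROCm build, but no GPU is visible"
        return "CPU-only build"

    if __name__ == "__main__":
        print(describe_platform())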