sglang · Commit 3a6e0418 (unverified)

[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)

Authored Sep 17, 2024 by HAI; committed by GitHub, Sep 17, 2024.
Parent commit: 2fa5cec7
Showing 11 changed files with 104 additions and 24 deletions (+104, -24).
python/sglang/srt/layers/activation.py          +12   -0
python/sglang/srt/layers/attention_backend.py   +11   -7
python/sglang/srt/layers/fused_moe/layer.py     +27   -7
python/sglang/srt/layers/layernorm.py           +12   -0
python/sglang/srt/layers/sampler.py             +10   -6
python/sglang/srt/lora/lora_manager.py           +5   -2
python/sglang/srt/models/deepseek_v2.py          +5   -1
python/sglang/srt/models/minicpm3.py             +5   -1
python/sglang/srt/server.py                      +5   -0
python/sglang/srt/server_args.py                 +7   -0
python/sglang/srt/utils.py                       +5   -0
python/sglang/srt/layers/activation.py  (+12, -0)

@@ -13,6 +13,7 @@ limitations under the License.
 """Fused operators for activation layers."""
 
+import logging
 from typing import Optional
 
 import torch
@@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +140,10 @@ def get_act_fn(
             act_fn, intermediate_size, input_is_parallel, params_dtype
         )
     return act_fn
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
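The same import-time fallback appears again in layernorm.py below: when running on ROCm, the FlashInfer-backed classes defined in the module are shadowed by vLLM's implementations just before the module finishes importing. A minimal, self-contained sketch of the idiom (placeholder class and logic only, not the actual sglang/vllm code):

```python
import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def is_hip() -> bool:
    # Same check the commit adds to sglang.srt.utils: ROCm builds of PyTorch
    # expose a HIP version string, CUDA builds expose None.
    return torch.version.hip is not None


class SiluAndMul:
    """Placeholder for the FlashInfer-backed activation used on NVIDIA GPUs."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]


if is_hip():
    logger.info("FlashInfer is not available on AMD GPUs. Falling back.")
    # The real commit rebinds the names to vLLM's own ops at this point, e.g.:
    # from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
```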
python/sglang/srt/layers/attention_backend.py  (+11, -7)

@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
 
 class AttentionBackend(ABC):
     """The base class of attention backends"""
python/sglang/srt/layers/fused_moe/layer.py  (+27, -7)

@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)
@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.a2_scale.max(), requires_grad=False
             )
 
+        # If ROCm, normalize the weights and scales to e4m3fnuz
+        if is_hip():
+            # Normalize the weights and scales
+            w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                layer.w13_weight, layer.w13_scale, layer.a13_scale
+            )
+            w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                layer.w2_weight, layer.w2_scale, layer.a2_scale
+            )
+            # Reset the parameters
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+            if a13_scale is not None:
+                layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+            if a2_scale is not None:
+                layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
         # Fp8 moe kernel needs single weight scale for w13 per expert.
         # We take the max then dequant and requant each expert.
         assert layer.w13_scale is not None
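The extra ROCm branch exists because MI300-class hardware uses the float8_e4m3fnuz encoding rather than float8_e4m3fn: the two formats share a bit layout, but the 0x80 pattern means negative zero in e4m3fn and NaN in e4m3fnuz, and identical bits decode to half the value in e4m3fnuz. vLLM's normalize_e4m3fn_to_e4m3fnuz compensates by fixing up that one bit pattern and doubling the scales. The sketch below is an illustrative approximation of that behavior (hypothetical helper name, not the actual vLLM code):

```python
from typing import Optional, Tuple

import torch


def to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Illustrative re-encoding of e4m3fn tensors into ROCm's e4m3fnuz format."""
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8).clone()
    # 0x80 (-128) encodes -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it.
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # For identical bits, the e4m3fnuz value is half the e4m3fn value,
    # so the dequantization scales must be doubled to compensate.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale
```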
python/sglang/srt/layers/layernorm.py  (+12, -0)

@@ -15,6 +15,7 @@ limitations under the License.
 """Fused operators for normalization layers."""
 
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
python/sglang/srt/layers/sampler.py  (+10, -6)

@@ -2,17 +2,21 @@ import logging
 from typing import Union
 
 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
 from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )
 
 logger = logging.getLogger(__name__)
python/sglang/srt/lora/lora_manager.py  (+5, -2)

@@ -21,12 +21,15 @@ import re
 from dataclasses import dataclass
 
 import torch
-from flashinfer import SegmentGEMMWrapper
 
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_hip, replace_submodule
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import SegmentGEMMWrapper
 
 
 def get_stacked_name(name):
python/sglang/srt/models/deepseek_v2.py  (+5, -1)

@@ -19,7 +19,6 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -48,6 +47,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
python/sglang/srt/models/minicpm3.py  (+5, -1)

@@ -19,7 +19,6 @@ import math
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class MiniCPM3MLP(nn.Module):
python/sglang/srt/server.py  (+5, -0)

@@ -78,6 +78,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     enable_show_time_cost,
+    is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model,
@@ -434,6 +435,10 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
+    if is_hip():
+        # to figure out a better method of not using fork later
+        mp.set_start_method("spawn", force=True)
+
 
 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
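The spawn switch matters because a fork()ed child inherits the parent's already-initialized GPU runtime state, which is unreliable with HIP just as it is with CUDA; spawning starts workers from a fresh interpreter instead (the in-code comment flags this as a stopgap). A minimal sketch of the same guard, assuming the is_hip() helper added in utils.py:

```python
import multiprocessing as mp

import torch


def is_hip() -> bool:
    return torch.version.hip is not None


if __name__ == "__main__":
    if is_hip():
        # Start workers from a clean interpreter instead of fork()ing,
        # so child processes never inherit initialized HIP state.
        mp.set_start_method("spawn", force=True)
```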
python/sglang/srt/server_args.py  (+7, -0)

@@ -21,6 +21,8 @@ import logging
 import random
 from typing import List, Optional, Union
 
+from sglang.srt.utils import is_hip
+
 logger = logging.getLogger(__name__)
@@ -164,6 +166,11 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"
 
+        # ROCm: flashinfer available later
+        if is_hip():
+            self.attention_backend = "triton"
+            self.sampling_backend = "pytorch"
+
         # Default kernel backends
         if self.enable_mla:
             logger.info("MLA optimization is tunred on. Use triton backend.")
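The net effect is that ROCm builds default to Triton attention kernels and PyTorch-native sampling instead of FlashInfer. A hedged sketch of that selection logic as a standalone helper (hypothetical function and defaults, not the actual ServerArgs code path):

```python
from typing import Optional, Tuple

import torch


def default_backends(
    attention_backend: Optional[str] = None,
    sampling_backend: Optional[str] = None,
) -> Tuple[str, str]:
    """Fill in kernel-backend defaults, preferring ROCm-friendly choices on AMD GPUs."""
    if torch.version.hip is not None:
        # FlashInfer is not available on ROCm yet, so fall back to Triton
        # attention kernels and PyTorch-native sampling.
        return attention_backend or "triton", sampling_backend or "pytorch"
    return attention_backend or "flashinfer", sampling_backend or "flashinfer"


print(default_backends())  # e.g. ('flashinfer', 'flashinfer') on a CUDA build
```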
python/sglang/srt/utils.py  (+5, -0)

@@ -51,6 +51,11 @@ show_time_cost = False
 time_infos = {}
 
 
+# torch flag AMD GPU
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
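The new is_hip() helper keys off the PyTorch build rather than the runtime hardware: torch.version.hip is a version string in ROCm wheels and None in CUDA or CPU-only wheels, so the check is cheap and works even before any GPU is touched. A quick way to see what an installed build reports:

```python
import torch

# ROCm builds report a HIP version string; CUDA and CPU-only builds report None.
print("torch.version.hip  =", torch.version.hip)
print("torch.version.cuda =", torch.version.cuda)
print("is_hip()           =", torch.version.hip is not None)
```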