sglang · Commits · 42f39099

Unverified commit 42f39099, authored Jan 13, 2025 by kk, committed by GitHub on Jan 13, 2025.

Unify sglang coding style (#2856)

Co-authored-by: Lin, Soga <soga.lin@amd.com>

Parent: 72c77763

Showing 2 changed files with 20 additions and 18 deletions:

python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+5 -4)
python/sglang/srt/layers/quantization/fp8.py (+15 -14)
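The unification replaces two recurring idioms: ad-hoc environment-variable parsing such as bool(int(os.getenv("CK_MOE", "0"))) becomes get_bool_env_var("CK_MOE"), and repeated is_hip() calls are replaced by a module-level is_hip_ flag computed once at import. The diff does not show the body of get_bool_env_var; below is a minimal sketch of the behavior implied by the expressions it replaces, and the real sglang helper may accept more spellings.

import os

# Minimal sketch of the helper this commit switches to, inferred from the
# inline idiom it replaces: bool(int(os.getenv("CK_MOE", "0"))). This is
# an assumption, not sglang's actual implementation.
def get_bool_env_var(name: str, default: str = "false") -> bool:
    value = os.getenv(name, default)
    return value.strip().lower() in ("1", "true")

Under this sketch, CK_MOE=1 and CK_MOE=true both enable the branch, while an unset variable falls back to the default.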
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

-import os
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple

@@ -19,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -28,6 +27,8 @@ else:
 import logging

+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)

@@ -99,7 +100,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
                 permute_weight(layer.w13_weight.data),
                 requires_grad=False,

@@ -163,7 +164,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )

-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             import ater
             from ater.fused_moe import fused_experts_ck
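Hoisting the check into a module-level flag works because the platform is fixed for the lifetime of the process, so every call site can test the cached value instead of re-calling the function. A self-contained sketch of the pattern follows; the is_hip() body here is a stand-in (sglang's real helper may be implemented differently), using torch.version.hip, which is None on non-ROCm builds of PyTorch.

import torch

# Stand-in for sglang.srt.utils.is_hip(); an assumption for illustration.
# torch.version.hip is None on CUDA/CPU builds and a version string on
# ROCm builds.
def is_hip() -> bool:
    return torch.version.hip is not None

# Evaluated once at import time, matching the is_hip_ convention this
# commit introduces; call sites then read the cached flag.
is_hip_ = is_hip()

if is_hip_:
    print("ROCm build: taking the HIP-specific path")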
python/sglang/srt/layers/quantization/fp8.py
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional

 import torch

@@ -47,6 +46,8 @@ from sglang.srt.utils import (
 ACTIVATION_SCHEMES = ["static", "dynamic"]

+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)

@@ -162,7 +163,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False

         self.block_quant = self.quant_config.weight_block_size is not None

@@ -274,7 +275,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,

@@ -331,7 +332,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,

@@ -568,7 +569,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,

@@ -595,7 +596,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

@@ -617,8 +618,8 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            if is_hip():
-                if bool(int(os.getenv("CK_MOE", "0"))):
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
                     layer.w13_weight = torch.nn.Parameter(
                         permute_weight(layer.w13_weight.data),
                         requires_grad=False,

@@ -629,7 +630,7 @@ class Fp8MoEMethod:
                         requires_grad=False,
                     )
                     torch.cuda.empty_cache()
-                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                elif get_bool_env_var("MOE_PADDING"):
                     # If ROCm, apply weight padding (min. Mem channel contention) only if set
                     layer.w13_weight = torch.nn.Parameter(
                         F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),

@@ -671,7 +672,7 @@ class Fp8MoEMethod:
             )
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(

@@ -721,8 +722,8 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )

-            if is_hip():
-                if bool(int(os.getenv("CK_MOE", "0"))):
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
                     layer.w13_weight = torch.nn.Parameter(
                         permute_weight(layer.w13_weight.data),
                         requires_grad=False,

@@ -733,7 +734,7 @@ class Fp8MoEMethod:
                         requires_grad=False,
                     )
                     torch.cuda.empty_cache()
-                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                elif get_bool_env_var("MOE_PADDING"):
                     # If ROCm, apply weight padding (min. Mem channel contention) only if set
                     layer.w13_weight = torch.nn.Parameter(
                         F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),

@@ -777,7 +778,7 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )

-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             import ater
             from ater.fused_moe import fused_experts_ck
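The MOE_PADDING branch retained in the hunks above zero-pads the last dimension of the expert weights with F.pad to reduce memory-channel contention on ROCm. Below is a toy demonstration of what that call does to a tensor's shape; padding_size and the weight shape are invented for the example, not sglang's actual values.

import torch
import torch.nn.functional as F

# Made-up values for illustration only.
padding_size = 128
w13_weight = torch.zeros(8, 256, 512)  # (num_experts, out_features, hidden)

# F.pad with a (left, right) pair pads only the last dimension, so the
# hidden dim grows from 512 to 512 + padding_size; the added entries are
# zeros.
padded = F.pad(w13_weight, (0, padding_size), "constant", 0)
print(padded.shape)  # torch.Size([8, 256, 640])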