Commit 5f65e2b8 (unverified)
Authored Oct 30, 2024 by HAI; committed by GitHub on Oct 30, 2024

[Performance, Hardware] MoE weights padding to AMD MI300x GPUs (#1836)
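This commit pads the last dimension of the fused-MoE expert weight tensors with 128 zero elements on ROCm, which, per the in-code comment, minimizes memory-channel contention on MI300x. The feature is opt-in via the MOE_PADDING environment variable. A minimal sketch of how a user would enable it, assuming a standard sglang launch (the model placeholder below is illustrative, not part of this commit):

```python
# Opt in to MoE weight padding before sglang imports its fused_moe module;
# the flag is read once, at import time (see the first hunk below).
import os

os.environ["MOE_PADDING"] = "1"

# Equivalent from a shell:
#   MOE_PADDING=1 python -m sglang.launch_server --model-path <model>
```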
Parent: 4e2af03c

Showing 2 changed files with 32 additions and 3 deletions (+32, -3):
- python/sglang/srt/layers/fused_moe/fused_moe.py (+4, -3)
- python/sglang/srt/layers/fused_moe/layer.py (+28, -0)
File: python/sglang/srt/layers/fused_moe/fused_moe.py
@@ -14,6 +14,7 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger

 logger = init_logger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


 @triton.jit
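Note that padding_size is a module-level constant, so the environment variable must be set before this module is first imported; flipping it afterwards has no effect. A minimal sketch of that behavior (import path as in this commit):

```python
import importlib
import os

os.environ["MOE_PADDING"] = "1"
from sglang.srt.layers.fused_moe import fused_moe

assert fused_moe.padding_size == 128  # gate evaluated at first import

os.environ["MOE_PADDING"] = "0"
importlib.reload(fused_moe)  # only an explicit reload re-evaluates it
assert fused_moe.padding_size == 0
```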
@@ -263,7 +264,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padding_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
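Only the K-extent argument to the kernel changes: it now receives the logical width B.shape[2] - padding_size, so the zero tail is never read. A minimal sketch of why this is safe, with small illustrative shapes and a slice standing in for the Triton kernel's bounds:

```python
import torch
import torch.nn.functional as F

padding_size = 128
B = torch.randn(4, 64, 32)              # (num_experts, N, K), illustrative
B_padded = F.pad(B, (0, padding_size))  # zero tail appended on the K dim
K = B_padded.shape[2] - padding_size    # the value now passed to the kernel

# Restricting reads to the logical K recovers the original weights exactly,
# so the kernel's output is unchanged by the padding.
assert torch.equal(B_padded[..., :K], B)
```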
@@ -464,7 +465,7 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
 ):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -481,7 +482,7 @@ def fused_experts(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
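The config lookup gets the same treatment: try_get_optimal_moe_config is keyed by w2's logical shape rather than its padded storage shape, so previously tuned kernel configs still match. A minimal sketch of the tuple built in this hunk (shapes illustrative):

```python
padding_size = 128
w2_padded_shape = (8, 4096, 14336 + padding_size)  # padded storage shape

logical_w2_shape = (
    w2_padded_shape[0],
    w2_padded_shape[1],
    w2_padded_shape[2] - padding_size,
)
assert logical_w2_shape == (8, 4096, 14336)  # what the tuner sees
```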
File: python/sglang/srt/layers/fused_moe/layer.py
 # Adapted from
 # https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
+import os
 from abc import abstractmethod
 from typing import List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs

+from sglang.srt.layers.fused_moe.fused_moe import padding_size
 from sglang.srt.utils import is_hip

 logger = init_logger(__name__)
@@ -506,6 +509,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

         # If checkpoint is fp8, we need to handle that the
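F.pad reads its pad tuple from the last dimension backwards, so (0, padding_size) appends zeros only to the right of dim -1; each torch.cuda.empty_cache() then lets the allocator release the superseded un-padded buffer. A minimal sketch of the pad semantics (shapes illustrative):

```python
import torch
import torch.nn.functional as F

padding_size = 128
w2 = torch.randn(2, 16, 32)  # (num_experts, hidden, intermediate), illustrative
w2_padded = F.pad(w2.data, (0, padding_size), "constant", 0)

assert w2_padded.shape == (2, 16, 32 + padding_size)
assert torch.all(w2_padded[..., -padding_size:] == 0)  # tail on dim -1 only
```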
@@ -572,6 +588,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 start += shard_size

             layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

     def apply(