Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
59bcc5b6
Unverified
Commit
59bcc5b6
authored
Jan 28, 2026
by
Rohan Potdar
Committed by
GitHub
Jan 28, 2026
Browse files
Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)
Signed-off-by:
Rohan138
<
rohanpotdar138@gmail.com
>
parent
3e440786
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
327 additions
and
11 deletions
+327
-11
tests/compile/test_fuse_act_padding.py
tests/compile/test_fuse_act_padding.py
+131
-0
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+2
-2
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+46
-0
vllm/compilation/pass_manager.py
vllm/compilation/pass_manager.py
+6
-2
vllm/compilation/rocm_aiter_fusion.py
vllm/compilation/rocm_aiter_fusion.py
+104
-1
vllm/config/compilation.py
vllm/config/compilation.py
+16
-0
vllm/config/vllm.py
vllm/config/vllm.py
+16
-0
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+5
-5
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+1
-1
No files found.
tests/compile/test_fuse_act_padding.py
0 → 100644
View file @
59bcc5b6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
vllm.config
from
vllm._aiter_ops
import
is_aiter_found_and_supported
,
rocm_aiter_ops
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.post_cleanup
import
PostCleanupPass
from
vllm.config
import
(
CompilationConfig
,
CompilationMode
,
ModelConfig
,
PassConfig
,
VllmConfig
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.utils
import
rocm_unquantized_gemm
from
.backend
import
TestBackend
class TestModel(torch.nn.Module):
    """Toy model that repeats a fused-add RMSNorm -> router GEMM -> zero-pad
    sequence once per layer.

    Each layer slices the activation back to ``hidden_size`` columns, applies
    an ``RMSNorm`` with a residual input, computes router logits with
    ``rocm_unquantized_gemm`` and finally zero-pads the normalized activation
    on its last dimension.

    Args:
        num_layers: Number of norm/router/pad layers to stack.
        hidden_size: Width of the (unpadded) activation.
        num_local_experts: Output features of every router ``Linear``.
        x_pad_to_multiple: The activation is padded so that its last
            dimension becomes a multiple of this value.
    """

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_local_experts: int,
        x_pad_to_multiple: int,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
        # Columns missing to reach the next multiple of x_pad_to_multiple.
        # The modulo form yields 0 (not a whole extra multiple) when
        # hidden_size is already aligned, which matches a ceil-style
        # "round up to multiple" output width.
        self.pad_dim = (-hidden_size) % x_pad_to_multiple
        self.norm = [RMSNorm(hidden_size, eps=1e-5) for _ in range(num_layers)]
        # One router per layer: forward() indexes routers by layer, so the
        # list length has to follow num_layers instead of a fixed count.
        self.router = [
            torch.nn.Linear(hidden_size, num_local_experts)
            for _ in range(num_layers)
        ]

    def forward(self, x):
        """Run all layers.

        Returns:
            A tuple ``(x, resid, *router_logits)`` holding the padded
            activation of the last layer, the last residual and the router
            logits of every layer in layer order.
        """
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
        all_router_logits = []
        for layer in range(self.num_layers):
            # Drop the padding appended by the previous layer (no-op on the
            # first iteration, where x still has hidden_size columns).
            x = x[:, : self.hidden_size]
            x, resid = self.norm[layer](x, resid)
            router_logits = rocm_unquantized_gemm(
                self, x, self.router[layer].weight, self.router[layer].bias
            )
            x = torch.nn.functional.pad(
                x, (0, self.pad_dim), mode="constant", value=0.0
            )
            all_router_logits.append(router_logits)

        return x, resid, *all_router_logits

    def ops_in_model_before(self):
        """Ops the test expects in the graph before the fusion pass runs."""
        return [
            rocm_aiter_ops.get_rmsnorm_fused_add_op(),
            torch.ops.aten.constant_pad_nd,
        ]

    def ops_in_model_after(self):
        """Ops the test expects in the graph after the fusion pass ran."""
        return [rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()]
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_layers", [3])
@pytest.mark.parametrize("hidden_size", [2880])
@pytest.mark.parametrize("num_local_experts", [128])
@pytest.mark.parametrize("x_pad_to_multiple", [256])
@pytest.mark.skipif(
    not is_aiter_found_and_supported(),
    reason="Only test on ROCm with AITER installed and supported",
)
def test_fuse_act_padding(
    dtype: torch.dtype,
    num_layers: int,
    hidden_size: int,
    num_local_experts: int,
    x_pad_to_multiple: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check that the add-RMSNorm + pad fusion pass fires once per layer.

    The model is run eagerly and through ``torch.compile`` with a
    ``TestBackend`` that applies the fusion pass; both outputs must be close,
    the pass must report ``num_layers`` matches, and the expected ops must be
    present before/after the pass.
    """
    model_config = ModelConfig(dtype=dtype)
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=["+rms_norm"],
        pass_config=PassConfig(fuse_act_padding=True, eliminate_noops=True),
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        compilation_config=compilation_config,
    )

    with (
        vllm.config.set_current_vllm_config(vllm_config),
        monkeypatch.context() as patched,
    ):
        # Imported inside the config context, as in the other fusion tests.
        from vllm.compilation.rocm_aiter_fusion import (
            RocmAiterTritonAddRMSNormPadFusionPass,
        )

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

        # Turn AITER on for this test only and make rocm_aiter_ops re-read it.
        patched.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        fusion_pass = RocmAiterTritonAddRMSNormPadFusionPass(vllm_config)
        backend = TestBackend(
            NoOpEliminationPass(vllm_config),
            fusion_pass,
            PostCleanupPass(vllm_config),
        )
        model = TestModel(
            num_layers, hidden_size, num_local_experts, x_pad_to_multiple
        )

        x = torch.rand(1, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        eager_outputs = model(x)

        compiled_model = torch.compile(model, backend=backend)
        compiled_outputs = compiled_model(x)

        torch.testing.assert_close(eager_outputs, compiled_outputs)

        assert fusion_pass.matched_count == num_layers

        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())
tests/compile/test_fusion.py
View file @
59bcc5b6
...
...
@@ -410,7 +410,7 @@ def test_aiter_fusion_rmsnorm_quant(
)
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
),
monkeypatch
.
context
()
as
m
:
from
vllm.compilation.rocm_aiter_fusion
import
RocmAiterRMSNormFusionPass
from
vllm.compilation.rocm_aiter_fusion
import
RocmAiterRMSNorm
Quant
FusionPass
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
...
...
@@ -420,7 +420,7 @@ def test_aiter_fusion_rmsnorm_quant(
torch
.
set_default_dtype
(
dtype
)
torch
.
manual_seed
(
1
)
fusion_pass
=
RocmAiterRMSNormFusionPass
(
vllm_config
)
fusion_pass
=
RocmAiterRMSNorm
Quant
FusionPass
(
vllm_config
)
model
=
TestModel
(
hidden_size
=
hidden_size
,
...
...
vllm/_aiter_ops.py
View file @
59bcc5b6
...
...
@@ -790,6 +790,41 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
return
x_fp8
,
out_bs
def _rocm_aiter_triton_add_rmsnorm_pad_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    residual: torch.Tensor,
    x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation of the ``rocm_aiter_triton_add_rmsnorm_pad`` op.

    Thin wrapper that forwards all arguments unchanged to AITER's Triton
    ``fused_add_rmsnorm_pad`` and returns its result as-is.
    """
    # Function-scope import: aiter is only needed once the op actually runs.
    from aiter.ops.triton.fused_add_rmsnorm_pad import (
        fused_add_rmsnorm_pad as aiter_fused_add_rmsnorm_pad,
    )

    fused_result = aiter_fused_add_rmsnorm_pad(
        x,
        weight,
        variance_epsilon,
        residual,
        x_pad_to_multiple=x_pad_to_multiple,
    )
    return fused_result
def
_rocm_aiter_triton_add_rmsnorm_pad_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
residual
:
torch
.
Tensor
,
x_pad_to_multiple
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
if
x_pad_to_multiple
>
0
:
N_out
=
(
N
+
x_pad_to_multiple
-
1
)
//
x_pad_to_multiple
*
x_pad_to_multiple
else
:
N_out
=
N
out
=
torch
.
empty
((
M
,
N_out
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
residual_out
=
torch
.
empty_like
(
residual
)
return
out
,
residual_out
# Global flag to ensure ops are registered only once
_OPS_REGISTERED
=
False
...
...
@@ -1108,6 +1143,13 @@ class rocm_aiter_ops:
fake_impl
=
_rocm_aiter_act_mul_and_fp8_group_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_triton_add_rmsnorm_pad"
,
op_func
=
_rocm_aiter_triton_add_rmsnorm_pad_impl
,
fake_impl
=
_rocm_aiter_triton_add_rmsnorm_pad_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_group_fp8_quant"
,
op_func
=
_rocm_aiter_group_fp8_quant_impl
,
...
...
@@ -1175,6 +1217,10 @@ class rocm_aiter_ops:
def
get_act_mul_fused_fp8_group_quant_op
()
->
OpOverload
:
return
torch
.
ops
.
vllm
.
rocm_aiter_act_mul_and_fp8_group_quant
.
default
@staticmethod
def get_triton_add_rmsnorm_pad_op() -> OpOverload:
    """Return the ``default`` overload of the
    ``torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad`` custom op."""
    op_packet = torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad
    return op_packet.default
@
staticmethod
def
rms_norm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
...
...
vllm/compilation/pass_manager.py
View file @
59bcc5b6
...
...
@@ -18,8 +18,9 @@ from .vllm_inductor_pass import VllmInductorPass
if
rocm_aiter_ops
.
is_enabled
():
from
vllm.compilation.rocm_aiter_fusion
import
(
RocmAiterRMSNormFusionPass
,
RocmAiterRMSNorm
Quant
FusionPass
,
RocmAiterSiluMulFp8GroupQuantFusionPass
,
RocmAiterTritonAddRMSNormPadFusionPass
,
)
if
current_platform
.
is_cuda_alike
():
...
...
@@ -123,13 +124,16 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
self
.
passes
+=
[
RMSNormQuantFusionPass
(
config
)]
if
rocm_aiter_ops
.
is_enabled
():
self
.
passes
+=
[
RocmAiterRMSNormFusionPass
(
config
),
RocmAiterRMSNorm
Quant
FusionPass
(
config
),
]
if
self
.
pass_config
.
fuse_act_quant
:
self
.
passes
+=
[
ActivationQuantFusionPass
(
config
)]
if
rocm_aiter_ops
.
is_enabled
():
self
.
passes
+=
[
RocmAiterSiluMulFp8GroupQuantFusionPass
(
config
)]
if
self
.
pass_config
.
fuse_act_padding
and
rocm_aiter_ops
.
is_enabled
():
self
.
passes
+=
[
RocmAiterTritonAddRMSNormPadFusionPass
(
config
)]
if
self
.
pass_config
.
fuse_attn_quant
:
self
.
passes
+=
[
AttnFusionPass
(
config
)]
...
...
vllm/compilation/rocm_aiter_fusion.py
View file @
59bcc5b6
...
...
@@ -266,7 +266,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
)
class
RocmAiterRMSNormFusionPass
(
VllmPatternMatcherPass
):
class
RocmAiterRMSNorm
Quant
FusionPass
(
VllmPatternMatcherPass
):
"""
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
into a fused rms_norm_quant op.
...
...
@@ -399,3 +399,106 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
AiterSiluMulFp8GroupQuantPattern
,
]
return
VllmInductorPass
.
hash_source
(
self
,
*
fusion_patterns
)
class AddAiterRMSNormPadPattern:
    """
    This pattern replaces an aiter_rmsnorm_with_add & a pad op
    with a custom triton_add_rmsnorm_pad op from AITER.

    The matched subgraph also contains the router GEMM that consumes the
    unpadded RMSNorm output; in the replacement that GEMM reads the first
    ``hidden_size`` columns of the padded output instead.

    Args:
        epsilon: RMSNorm epsilon the pattern is specialized for.
        hidden_size: Unpadded width of the normalized activation.
        x_pad_to_multiple: The activation is padded so its last dimension
            becomes a multiple of this value.
    """

    AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()

    def __init__(
        self,
        epsilon: float,
        hidden_size: int,
        x_pad_to_multiple: int,
    ):
        self.epsilon = epsilon
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors used to trace the pattern and replacement.

        The router weight/bias are small placeholders created with the dtype
        and device of the RMSNorm weight.
        NOTE(review): the [8, 16] router weight presumably has to agree with
        the last dimension of the matcher's example input - confirm against
        MatcherFusedAddRMSNorm.inputs().
        """
        input, weight, residual = self.rmsnorm_matcher.inputs()
        router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
        router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
        return [input, weight, residual, router_weight, router_bias]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""
        # Columns missing to reach the next multiple of x_pad_to_multiple.
        # The modulo form gives 0 when hidden_size is already aligned; the
        # previous `multiple - (hidden % multiple)` form gave a whole extra
        # multiple in that case, which does not correspond to the
        # round-up-to-multiple width the fused op produces.
        pad_size = (-self.hidden_size) % self.x_pad_to_multiple

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_rms, router_weight, router_bias
            )
            result = torch.nn.functional.pad(
                result_rms, (0, pad_size), mode="constant", value=0.0
            )
            return result, residual_out, router_logits

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                residual=residual,
                x_pad_to_multiple=self.x_pad_to_multiple,
            )
            result_padded = at[0]
            # The router must only see the real hidden_size columns, not the
            # zero padding appended by the fused op.
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_padded[:, : self.hidden_size], router_weight, router_bias
            )
            residual_out = at[1]
            return result_padded, residual_out, router_logits

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
    """
    This pass replaces an AITER CK RMSNorm + residual add and a pad op
    with a triton_add_rmsnorm_pad op from AITER.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
        )

        # gpt-oss has hidden size 2880
        # padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
        gpt_oss_hidden_size = 2880
        # Register one pattern per (epsilon, pad multiple) combination,
        # epsilon-major, i.e. the same order as two nested loops.
        variants = [
            (eps, multiple) for eps in (1e-5, 1e-6) for multiple in (128, 256)
        ]
        for eps, multiple in variants:
            fused_pattern = AddAiterRMSNormPadPattern(
                eps, gpt_oss_hidden_size, multiple
            )
            fused_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` and record the number
        of replacements in ``self.matched_count``."""
        num_replaced = self.patterns.apply(graph)
        self.matched_count = num_replaced
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash of this pass's source together with the pattern class."""
        pattern_sources = (AddAiterRMSNormPadPattern,)
        return VllmInductorPass.hash_source(self, *pattern_sources)
vllm/config/compilation.py
View file @
59bcc5b6
...
...
@@ -126,6 +126,10 @@ class PassConfig:
fuse_allreduce_rms
:
bool
=
Field
(
default
=
None
)
"""Enable flashinfer allreduce fusion."""
# ROCm/AITER specific fusions
fuse_act_padding
:
bool
=
Field
(
default
=
None
)
"""Fuse the custom RMSNorm + padding ops."""
fi_allreduce_fusion_max_size_mb
:
float
|
None
=
None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
...
...
@@ -194,6 +198,7 @@ class PassConfig:
"enable_sp"
,
"fuse_gemm_comms"
,
"fuse_allreduce_rms"
,
"fuse_act_padding"
,
mode
=
"wrap"
,
)
@
classmethod
...
...
@@ -222,12 +227,23 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if
self
.
fuse_act_padding
:
logger
.
warning_once
(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + padding fusion might not work"
)
if
self
.
enable_qk_norm_rope_fusion
and
not
current_platform
.
is_cuda_alike
():
logger
.
warning_once
(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA or ROCm. The fusion will be disabled."
)
self
.
enable_qk_norm_rope_fusion
=
False
if
self
.
fuse_act_padding
and
not
current_platform
.
is_rocm
():
logger
.
warning_once
(
"Padding fusion enabled but the current platform is not ROCm. "
"The fusion will be disabled."
)
self
.
fuse_act_padding
=
False
class
DynamicShapesType
(
str
,
enum
.
Enum
):
...
...
vllm/config/vllm.py
View file @
59bcc5b6
...
...
@@ -102,6 +102,18 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool:
)
or
cfg
.
compilation_config
.
is_custom_op_enabled
(
"quant_fp8"
)
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
    """Enable if using AITER RMSNorm and AITER Triton GEMMs
    and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion.

    Returns False (instead of raising) when ``cfg.model_config`` is None,
    so the default can be evaluated for configs built without a model.
    """
    aiter_rmsnorm_and_gemm = (
        envs.VLLM_ROCM_USE_AITER
        and envs.VLLM_ROCM_USE_AITER_RMSNORM
        and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
    )
    if not aiter_rmsnorm_and_gemm:
        return False

    model_config = cfg.model_config
    if model_config is None:
        # No model attached: the hidden size is unknown, keep fusion off.
        return False
    return model_config.get_hidden_size() == 2880
OPTIMIZATION_LEVEL_00
=
{
"compilation_config"
:
{
"pass_config"
:
{
...
...
@@ -112,6 +124,7 @@ OPTIMIZATION_LEVEL_00 = {
"fuse_attn_quant"
:
False
,
"enable_sp"
:
False
,
"fuse_gemm_comms"
:
False
,
"fuse_act_padding"
:
False
,
},
"cudagraph_mode"
:
CUDAGraphMode
.
NONE
,
"use_inductor_graph_partition"
:
False
,
...
...
@@ -127,6 +140,7 @@ OPTIMIZATION_LEVEL_01 = {
"fuse_attn_quant"
:
False
,
"enable_sp"
:
False
,
"fuse_gemm_comms"
:
False
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
},
"cudagraph_mode"
:
CUDAGraphMode
.
PIECEWISE
,
"use_inductor_graph_partition"
:
False
,
...
...
@@ -142,6 +156,7 @@ OPTIMIZATION_LEVEL_02 = {
"fuse_attn_quant"
:
IS_QUANTIZED
,
"enable_sp"
:
IS_DENSE
,
"fuse_gemm_comms"
:
IS_DENSE
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
},
"cudagraph_mode"
:
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
"use_inductor_graph_partition"
:
False
,
...
...
@@ -157,6 +172,7 @@ OPTIMIZATION_LEVEL_03 = {
"fuse_attn_quant"
:
IS_QUANTIZED
,
"enable_sp"
:
IS_DENSE
,
"fuse_gemm_comms"
:
IS_DENSE
,
"fuse_act_padding"
:
enable_norm_pad_fusion
,
},
"cudagraph_mode"
:
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
"use_inductor_graph_partition"
:
False
,
...
...
vllm/model_executor/layers/utils.py
View file @
59bcc5b6
...
...
@@ -137,6 +137,11 @@ def rocm_unquantized_gemm_impl(
import
math
if
use_aiter_triton_gemm
(
n
,
m
,
k
,
x
.
dtype
):
from
aiter.ops.triton.gemm_a16w16
import
gemm_a16w16
return
gemm_a16w16
(
x
,
weight
,
bias
)
use_skinny_reduce_counting
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_gfx950
()
...
...
@@ -155,11 +160,6 @@ def rocm_unquantized_gemm_impl(
out
=
ops
.
wvSplitKrc
(
weight
,
x_view
,
cu_count
,
bias
)
return
out
.
reshape
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
if
use_aiter_triton_gemm
(
n
,
m
,
k
,
x
.
dtype
):
from
aiter.ops.triton.gemm_a16w16
import
gemm_a16w16
return
gemm_a16w16
(
x
,
weight
,
bias
)
use_skinny
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_gfx9
()
...
...
vllm/model_executor/models/gpt_oss.py
View file @
59bcc5b6
...
...
@@ -187,7 +187,7 @@ class MLPBlock(torch.nn.Module):
)
else
:
g
=
self
.
router
(
x
)
x
=
self
.
experts
(
hidden_states
=
x
,
router_logits
=
g
)
x
=
self
.
experts
(
hidden_states
=
x
,
router_logits
=
g
)
[:,
:
self
.
hidden_size
]
if
self
.
is_sequence_parallel
:
x
=
tensor_model_parallel_all_gather
(
x
.
contiguous
(),
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment