Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38d80967
Commit
38d80967
authored
Sep 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori
parents
33650733
880c741b
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
138 additions
and
25 deletions
+138
-25
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+1
-3
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+2
-3
vllm/model_executor/layers/fused_moe/routing_simulator.py
vllm/model_executor/layers/fused_moe/routing_simulator.py
+5
-3
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+130
-16
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
38d80967
...
...
@@ -38,9 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
mk
.
ExpertTokensMetadata
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
mk
.
PrepareResultType
:
if
apply_router_weight_on_input
:
topk
=
topk_ids
.
size
(
1
)
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
38d80967
...
...
@@ -420,9 +420,8 @@ def shuffle_weights(
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the
block sizes used to divide the tensors during shuffling.
Default is (16, 16).
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
...
...
vllm/model_executor/layers/fused_moe/routing_simulator.py
View file @
38d80967
...
...
@@ -10,7 +10,7 @@ like uniform random routing.
"""
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
torch
...
...
@@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy):
distributions for testing different routing patterns.
"""
def
__init__
(
self
,
distribution
:
str
=
"uniform"
,
**
distribution_params
):
def
__init__
(
self
,
distribution
:
str
=
"uniform"
,
**
distribution_params
:
Any
):
"""
Initialize distribution-based routing.
...
...
@@ -244,7 +246,7 @@ class RoutingSimulator:
cls
.
_routing_strategies
[
name
]
=
strategy
@
classmethod
def
get_available_strategies
(
cls
):
def
get_available_strategies
(
cls
)
->
list
[
str
]
:
"""
Get list of available routing strategy names.
...
...
vllm/model_executor/layers/layernorm.py
View file @
38d80967
...
...
@@ -9,11 +9,11 @@ import torch.nn as nn
import
vllm.envs
as
envs
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
def
is_rocm_aiter_rmsnorm_enabled
()
->
bool
:
return
current_platform
.
is_rocm
()
\
and
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
\
return
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
\
and
envs
.
VLLM_ROCM_USE_AITER
...
...
@@ -43,8 +43,22 @@ def fused_add_rms_norm(
return
x
,
residual
def
rocm_aiter_rms_norm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
torch
.
Tensor
:
def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
              variance_epsilon: float) -> torch.Tensor:
    """Run the custom-op polynomial normalization and return its output.

    A fresh tensor is allocated with ``torch.empty_like(x)`` and handed to
    ``ops.poly_norm`` as the first argument, followed by the input, weight,
    bias and epsilon. The allocated tensor is what gets returned.

    Args:
        x: Input tensor; the result is allocated with ``empty_like`` of it.
        weight: Weight tensor forwarded unchanged to the custom op.
        bias: Bias tensor forwarded unchanged to the custom op.
        variance_epsilon: Epsilon forwarded unchanged to the custom op.

    Returns:
        The tensor allocated here and passed to ``ops.poly_norm``
        (presumably filled in place by the kernel -- confirm against the
        op's definition).
    """
    # Imported lazily, at call time rather than at module import time.
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result
def
rocm_aiter_rms_norm_impl
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
torch
.
Tensor
:
import
aiter
as
rocm_aiter
if
x
.
dim
()
>
2
:
x_original_shape
=
x
.
shape
...
...
@@ -55,7 +69,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
return
rocm_aiter
.
rms_norm
(
x
,
weight
,
variance_epsilon
)
def
rocm_aiter_
fused_add_rms_norm
(
def
rocm_aiter_
rmsnorm2d_fwd_with_add_impl
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -74,14 +88,48 @@ def rocm_aiter_fused_add_rms_norm(
return
output
,
residual_out
def
dispatch_cuda_rmsnorm_func
(
add_residual
:
bool
):
if
add_residual
:
if
is_rocm_aiter_rmsnorm_enabled
():
return
rocm_aiter_fused_add_rms_norm
return
fused_add_rms_norm
def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """Fake (no-compute) counterpart of the AITER RMS norm op.

    Only the input ``x`` is consulted: the result is an uninitialized tensor
    produced by ``torch.empty_like(x)``. ``weight`` and ``variance_epsilon``
    are accepted so the signature mirrors the real implementation but are
    not read.

    Returns:
        An uninitialized tensor created with ``torch.empty_like(x)``.
    """
    placeholder = torch.empty_like(x)
    return placeholder
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (no-compute) counterpart of the AITER fused add + RMS norm op.

    Produces two uninitialized tensors: the first from
    ``torch.empty_like(x)`` and the second from
    ``torch.empty_like(residual)``. ``weight`` and ``variance_epsilon`` are
    part of the signature but are not read.

    Returns:
        A ``(output, residual_out)`` pair of uninitialized tensors.
    """
    output_like = torch.empty_like(x)
    residual_like = torch.empty_like(residual)
    return output_like, residual_like
if
is_rocm_aiter_rmsnorm_enabled
():
return
rocm_aiter_rms_norm
# Import-time registration of the two AITER RMS-norm custom ops. This only
# runs when the current platform reports ROCm; on other platforms neither op
# is registered by this block.
if current_platform.is_rocm():
    # Plain RMS norm: real implementation plus its fake implementation.
    # mutates_args=[] declares that no argument is mutated by the op.
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        op_func=rocm_aiter_rms_norm_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_rms_norm_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    # Fused residual-add + RMS norm, registered the same way.
    direct_register_custom_op(
        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
        dispatch_key=current_platform.dispatch_key,
    )
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
    """Pick the RMS-norm callable to use on ROCm.

    The AITER ops are chosen only when ``is_rocm_aiter_rmsnorm_enabled()``
    is truthy and ``dtype`` is ``torch.float16`` or ``torch.bfloat16``;
    otherwise the non-AITER ``fused_add_rms_norm`` / ``rms_norm`` functions
    are returned.

    Args:
        with_fused_add: Select the variant that also takes a residual.
        dtype: Dtype used to decide whether the AITER path is eligible.

    Returns:
        The selected callable (it is returned, not invoked).
    """
    aiter_dtypes = (torch.float16, torch.bfloat16)

    # The enablement check is evaluated first, exactly as a short-circuit
    # `and`; the torch.ops.vllm attributes are only touched on this path.
    if is_rocm_aiter_rmsnorm_enabled() and dtype in aiter_dtypes:
        if with_fused_add:
            return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        return torch.ops.vllm.rocm_aiter_rms_norm

    # fall back to CUDA implementation
    return fused_add_rms_norm if with_fused_add else rms_norm
...
...
@@ -114,6 +162,13 @@ class RMSNorm(CustomOp):
self
.
weight
=
torch
.
ones
(
hidden_size
)
if
self
.
has_weight
:
self
.
weight
=
nn
.
Parameter
(
self
.
weight
)
weight_dtype
=
self
.
weight
.
data
.
dtype
if
current_platform
.
is_rocm
():
self
.
rocm_norm_func
=
dispatch_rocm_rmsnorm_func
(
with_fused_add
=
False
,
dtype
=
weight_dtype
)
self
.
rocm_norm_func_with_add
=
dispatch_rocm_rmsnorm_func
(
with_fused_add
=
True
,
dtype
=
weight_dtype
)
def
forward_native
(
self
,
...
...
@@ -162,13 +217,27 @@ class RMSNorm(CustomOp):
return
self
.
forward_native
(
x
,
residual
)
add_residual
=
residual
is
not
None
norm_func
=
dispatch_cuda_rmsnorm_func
(
add_residual
)
if
add_residual
:
return
fused_add_rms_norm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
else
:
return
rms_norm
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
def
forward_hip
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
self
.
variance_size_override
is
not
None
:
return
self
.
forward_native
(
x
,
residual
)
add_residual
=
residual
is
not
None
if
add_residual
:
return
norm_func
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
self
.
rocm_norm_func_with_add
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
else
:
return
norm_func
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
self
.
rocm_norm_func
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
def
forward_xpu
(
self
,
...
...
@@ -265,3 +334,48 @@ class GemmaRMSNorm(CustomOp):
self
.
forward_static
)
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization layer.

    Maps x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b,
    where the w_n are learned weights and b is a learned bias.
    Reference: https://arxiv.org/html/2411.03884v1
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        # Three weights, each initialised to 1/3; a single bias at zero.
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        """Divide ``x`` by sqrt(mean(x^2) + eps), mean over the last dim."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.variance_epsilon)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch forward pass.

        The computation is carried out in float32 and the result is cast
        back to the dtype of ``x``.
        See https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """
        input_dtype = x.dtype
        x_fp32 = x.to(torch.float32)

        cubic_term = self.weight[0] * self._norm(x_fp32**3)
        quadratic_term = self.weight[1] * self._norm(x_fp32**2)
        linear_term = self.weight[2] * self._norm(x_fp32)

        # Summed left to right in the same order as the formula above.
        result = cubic_term + quadratic_term + linear_term + self.bias
        return result.to(input_dtype)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass that delegates to the module-level ``poly_norm``."""
        return poly_norm(
            x,
            self.weight,
            self.bias,
            self.variance_epsilon,
        )
Prev
1
…
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment