Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
156405d2
Unverified
Commit
156405d2
authored
Apr 05, 2026
by
Xiaoshuang Wang
Committed by
GitHub
Apr 04, 2026
Browse files
[vLLM IR] gemma_rms_norm (#38780)
Signed-off-by:
Icey
<
1790571317@qq.com
>
parent
99e5539a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
67 deletions
+25
-67
vllm/ir/ops/layernorm.py
vllm/ir/ops/layernorm.py
+2
-3
vllm/kernels/aiter_ops.py
vllm/kernels/aiter_ops.py
+4
-6
vllm/kernels/vllm_c.py
vllm/kernels/vllm_c.py
+5
-2
vllm/kernels/xpu_ops.py
vllm/kernels/xpu_ops.py
+3
-1
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+11
-55
No files found.
vllm/ir/ops/layernorm.py
View file @
156405d2
...
@@ -16,7 +16,6 @@ def rms_norm(
...
@@ -16,7 +16,6 @@ def rms_norm(
x_var
=
x
if
variance_size
is
None
else
x
[...,
:
variance_size
]
x_var
=
x
if
variance_size
is
None
else
x
[...,
:
variance_size
]
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
epsilon
)
x
=
x
*
torch
.
rsqrt
(
variance
+
epsilon
)
x
=
x
.
to
(
orig_dtype
)
if
weight
is
not
None
:
if
weight
is
not
None
:
x
=
x
*
weight
x
=
x
.
to
(
weight
.
dtype
)
*
weight
return
x
return
x
.
to
(
orig_dtype
)
vllm/kernels/aiter_ops.py
View file @
156405d2
...
@@ -36,13 +36,11 @@ AITER_SUPPORTED = is_aiter_found()
...
@@ -36,13 +36,11 @@ AITER_SUPPORTED = is_aiter_found()
rms_no_var_16bit_only
=
(
rms_no_var_16bit_only
=
(
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
and
x
.
dtype
and
x
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
in
(
and
(
weight
is
None
or
weight
.
dtype
==
x
.
dtype
)
torch
.
float16
,
torch
.
bfloat16
,
)
)
)
"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
"""AITER rms_norm only supports float16 and bfloat16 acts, no var_size override,
and requires weight dtype to match x dtype."""
@
ir
.
ops
.
rms_norm
.
register_impl
(
@
ir
.
ops
.
rms_norm
.
register_impl
(
...
...
vllm/kernels/vllm_c.py
View file @
156405d2
...
@@ -11,8 +11,11 @@ current_platform.import_kernels()
...
@@ -11,8 +11,11 @@ current_platform.import_kernels()
CUDA_ALIKE
=
current_platform
.
is_cuda_alike
()
CUDA_ALIKE
=
current_platform
.
is_cuda_alike
()
"""Most kernels in this file are supported on all CUDA-alike platforms."""
"""Most kernels in this file are supported on all CUDA-alike platforms."""
rms_no_var_size
=
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
rms_no_var_size
=
(
"""vLLM kernel does not support variance_size parameter."""
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
and
(
weight
is
None
or
weight
.
dtype
==
x
.
dtype
)
)
"""vLLM kernel does not support variance_size parameter or mismatched weight dtype."""
@
ir
.
ops
.
rms_norm
.
register_impl
(
@
ir
.
ops
.
rms_norm
.
register_impl
(
...
...
vllm/kernels/xpu_ops.py
View file @
156405d2
...
@@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool:
...
@@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool:
XPU_KERNELS_SUPPORTED
=
is_xpu_kernels_found
()
XPU_KERNELS_SUPPORTED
=
is_xpu_kernels_found
()
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
rms_no_var
=
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
rms_no_var
=
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
and
(
weight
is
None
or
weight
.
dtype
==
x
.
dtype
)
@
ir
.
ops
.
rms_norm
.
register_impl
(
@
ir
.
ops
.
rms_norm
.
register_impl
(
...
...
vllm/model_executor/layers/layernorm.py
View file @
156405d2
...
@@ -376,46 +376,6 @@ class GemmaRMSNorm(CustomOp):
...
@@ -376,46 +376,6 @@ class GemmaRMSNorm(CustomOp):
self
.
weight
=
nn
.
Parameter
(
torch
.
zeros
(
hidden_size
))
self
.
weight
=
nn
.
Parameter
(
torch
.
zeros
(
hidden_size
))
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
@
staticmethod
def
_forward_static_no_residual
(
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward() without residual."""
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
.
to
(
orig_dtype
)
return
x
@
staticmethod
def
_forward_static_with_residual
(
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward() with residual."""
orig_dtype
=
x
.
dtype
x
=
(
x
.
float
()
+
residual
.
float
()
if
orig_dtype
==
torch
.
float16
else
x
+
residual
)
residual
=
x
x
=
x
.
float
()
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
.
to
(
orig_dtype
)
return
x
,
residual
def
forward_native
(
def
forward_native
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -423,30 +383,26 @@ class GemmaRMSNorm(CustomOp):
...
@@ -423,30 +383,26 @@ class GemmaRMSNorm(CustomOp):
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
if
residual
is
None
:
if
residual
is
None
:
return
self
.
_forward_static_no_residual
(
return
ir
.
ops
.
rms_norm
(
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
x
,
self
.
weight
.
data
.
float
()
+
1.0
,
self
.
variance_epsilon
)
)
else
:
else
:
return
self
.
_forward_static_with_residual
(
orig_dtype
=
x
.
dtype
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
,
residual
x
=
(
x
.
float
()
+
residual
.
float
()
if
orig_dtype
==
torch
.
float16
else
x
+
residual
)
)
residual
=
x
return
ir
.
ops
.
rms_norm
(
x
,
self
.
weight
.
data
.
float
()
+
1.0
,
self
.
variance_epsilon
).
to
(
orig_dtype
),
residual
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
torch
.
compiler
.
is_compiling
():
return
self
.
forward_native
(
x
,
residual
)
if
not
getattr
(
self
,
"_is_compiled"
,
False
):
self
.
_forward_static_no_residual
=
torch
.
compile
(
# type: ignore
self
.
_forward_static_no_residual
)
self
.
_forward_static_with_residual
=
torch
.
compile
(
# type: ignore
self
.
_forward_static_with_residual
)
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
return
self
.
forward_native
(
x
,
residual
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment