Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d266abb
Unverified
Commit
7d266abb
authored
Apr 04, 2026
by
Robert Shaw
Committed by
GitHub
Apr 04, 2026
Browse files
Revert "[vLLM IR] gemma_rms_norm" (#38998)
parent
156405d2
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
25 deletions
+67
-25
vllm/ir/ops/layernorm.py
vllm/ir/ops/layernorm.py
+3
-2
vllm/kernels/aiter_ops.py
vllm/kernels/aiter_ops.py
+6
-4
vllm/kernels/vllm_c.py
vllm/kernels/vllm_c.py
+2
-5
vllm/kernels/xpu_ops.py
vllm/kernels/xpu_ops.py
+1
-3
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+55
-11
No files found.
vllm/ir/ops/layernorm.py
View file @
7d266abb
...
...
@@ -16,6 +16,7 @@ def rms_norm(
x_var
=
x
if
variance_size
is
None
else
x
[...,
:
variance_size
]
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
epsilon
)
x
=
x
.
to
(
orig_dtype
)
if
weight
is
not
None
:
x
=
x
.
to
(
weight
.
dtype
)
*
weight
return
x
.
to
(
orig_dtype
)
x
=
x
*
weight
return
x
vllm/kernels/aiter_ops.py
View file @
7d266abb
...
...
@@ -36,11 +36,13 @@ AITER_SUPPORTED = is_aiter_found()
rms_no_var_16bit_only
=
(
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
and
x
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
and
(
weight
is
None
or
weight
.
dtype
==
x
.
dtype
)
and
x
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
,
)
)
"""AITER rms_norm only supports float16 and bfloat16 acts, no var_size override,
and requires weight dtype to match x dtype."""
"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
@
ir
.
ops
.
rms_norm
.
register_impl
(
...
...
vllm/kernels/vllm_c.py
View file @
7d266abb
...
...
@@ -11,11 +11,8 @@ current_platform.import_kernels()
CUDA_ALIKE
=
current_platform
.
is_cuda_alike
()
"""Most kernels in this file are supported on all CUDA-alike platforms."""
rms_no_var_size
=
(
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
and
(
weight
is
None
or
weight
.
dtype
==
x
.
dtype
)
)
"""vLLM kernel does not support variance_size parameter or mismatched weight dtype."""
rms_no_var_size
=
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
"""vLLM kernel does not support variance_size parameter."""
@
ir
.
ops
.
rms_norm
.
register_impl
(
...
...
vllm/kernels/xpu_ops.py
View file @
7d266abb
...
...
@@ -18,9 +18,7 @@ def is_xpu_kernels_found() -> bool:
XPU_KERNELS_SUPPORTED
=
is_xpu_kernels_found
()
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
rms_no_var
=
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
and
(
weight
is
None
or
weight
.
dtype
==
x
.
dtype
)
rms_no_var
=
lambda
x
,
weight
,
epsilon
,
variance_size
=
None
:
variance_size
is
None
@
ir
.
ops
.
rms_norm
.
register_impl
(
...
...
vllm/model_executor/layers/layernorm.py
View file @
7d266abb
...
...
@@ -376,6 +376,46 @@ class GemmaRMSNorm(CustomOp):
self
.
weight
=
nn
.
Parameter
(
torch
.
zeros
(
hidden_size
))
self
.
variance_epsilon
=
eps
@
staticmethod
def
_forward_static_no_residual
(
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward() without residual."""
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
.
to
(
orig_dtype
)
return
x
@
staticmethod
def
_forward_static_with_residual
(
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward() with residual."""
orig_dtype
=
x
.
dtype
x
=
(
x
.
float
()
+
residual
.
float
()
if
orig_dtype
==
torch
.
float16
else
x
+
residual
)
residual
=
x
x
=
x
.
float
()
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
.
to
(
orig_dtype
)
return
x
,
residual
def
forward_native
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -383,26 +423,30 @@ class GemmaRMSNorm(CustomOp):
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
if
residual
is
None
:
return
ir
.
ops
.
rms_norm
(
x
,
self
.
weight
.
data
.
float
()
+
1.0
,
self
.
variance_epsilon
return
self
.
_forward_static_no_residual
(
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
)
else
:
orig_dtype
=
x
.
dtype
x
=
(
x
.
float
()
+
residual
.
float
()
if
orig_dtype
==
torch
.
float16
else
x
+
residual
return
self
.
_forward_static_with_residual
(
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
,
residual
)
residual
=
x
return
ir
.
ops
.
rms_norm
(
x
,
self
.
weight
.
data
.
float
()
+
1.0
,
self
.
variance_epsilon
).
to
(
orig_dtype
),
residual
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
torch
.
compiler
.
is_compiling
():
return
self
.
forward_native
(
x
,
residual
)
if
not
getattr
(
self
,
"_is_compiled"
,
False
):
self
.
_forward_static_no_residual
=
torch
.
compile
(
# type: ignore
self
.
_forward_static_no_residual
)
self
.
_forward_static_with_residual
=
torch
.
compile
(
# type: ignore
self
.
_forward_static_with_residual
)
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment