Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
424fa81f
"vscode:/vscode.git/clone" did not exist on "e858bfe05167a3bbb064e283da5a1a7709dee24e"
Commit
424fa81f
authored
Jan 05, 2026
by
zhuwenwen
Browse files
back to forward_static
parent
57e945fd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
33 deletions
+33
-33
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+33
-33
No files found.
vllm/model_executor/layers/layernorm.py
View file @
424fa81f
...
...
@@ -148,7 +148,7 @@ class RMSNorm(CustomOp):
@
staticmethod
def
forward_static
(
self
,
#
self,
x
:
torch
.
Tensor
,
variance_epsilon
:
float
,
hidden_size
:
int
,
...
...
@@ -158,45 +158,45 @@ class RMSNorm(CustomOp):
variance_size_override
:
int
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
if
not
torch
.
compiler
.
is_compiling
()
and
envs
.
VLLM_USE_OPT_OP
:
return
self
.
forward_cuda
(
x
,
residual
)
# if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
# return self.forward_cuda(x, residual)
# else:
orig_dtype
=
x
.
dtype
x
=
x
.
to
(
torch
.
float32
)
if
residual
is
not
None
:
# residual promoted f16->f32 automatically,
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x
=
x
+
residual
residual
=
x
.
to
(
orig_dtype
)
if
x
.
shape
[
-
1
]
!=
hidden_size
:
raise
ValueError
(
f
"Expected hidden_size to be
{
hidden_size
}
, but found:
{
x
.
shape
[
-
1
]
}
"
)
if
variance_size_override
is
None
:
x_var
=
x
else
:
orig_dtype
=
x
.
dtype
x
=
x
.
to
(
torch
.
float32
)
if
residual
is
not
None
:
# residual promoted f16->f32 automatically,
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x
=
x
+
residual
residual
=
x
.
to
(
orig_dtype
)
if
x
.
shape
[
-
1
]
!=
hidden_size
:
if
hidden_size
<
variance_size_override
:
raise
ValueError
(
f
"Expected hidden_size to be
{
hidden_size
}
, but found:
{
x
.
shape
[
-
1
]
}
"
"Expected hidden_size to be at least "
f
"
{
variance_size_override
}
, but found:
{
hidden_size
}
"
)
if
variance_size_override
is
None
:
x_var
=
x
else
:
if
hidden_size
<
variance_size_override
:
raise
ValueError
(
"Expected hidden_size to be at least "
f
"
{
variance_size_override
}
, but found:
{
hidden_size
}
"
)
x_var
=
x
[:,
:,
:
variance_size_override
]
x_var
=
x
[:,
:,
:
variance_size_override
]
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
x
=
x
.
to
(
orig_dtype
)
if
weight
is
not
None
:
x
=
x
*
weight
if
residual
is
None
:
return
x
else
:
return
x
,
residual
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
x
=
x
.
to
(
orig_dtype
)
if
weight
is
not
None
:
x
=
x
*
weight
if
residual
is
None
:
return
x
else
:
return
x
,
residual
def
forward_native
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment