Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b3856bef
Unverified
Commit
b3856bef
authored
Aug 22, 2024
by
Woosuk Kwon
Committed by
GitHub
Aug 22, 2024
Browse files
[Misc] Use torch.compile for GemmaRMSNorm (#7642)
parent
8c6f694a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
6 deletions
+23
-6
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+23
-6
No files found.
vllm/model_executor/layers/layernorm.py
View file @
b3856bef
...
...
@@ -114,10 +114,12 @@ class GemmaRMSNorm(CustomOp):
self
.
weight
=
nn
.
Parameter
(
torch
.
zeros
(
hidden_size
))
self
.
variance_epsilon
=
eps
def
forward_native
(
self
,
@
staticmethod
def
forward_static
(
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype
=
x
.
dtype
...
...
@@ -127,17 +129,32 @@ class GemmaRMSNorm(CustomOp):
x
=
x
.
float
()
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x
=
x
*
(
1.0
+
self
.
weight
.
float
())
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
.
to
(
orig_dtype
)
return
x
if
residual
is
None
else
(
x
,
residual
)
def
forward_native
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""PyTorch-native implementation equivalent to forward()."""
return
self
.
forward_static
(
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
,
residual
)
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
if
torch
.
compiler
.
is_compiling
():
return
self
.
forward_native
(
x
,
residual
)
if
not
getattr
(
self
,
"_is_compiled"
,
False
):
self
.
forward_static
=
torch
.
compile
(
# type: ignore
self
.
forward_static
)
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment