Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
900edfa8
Unverified
Commit
900edfa8
authored
Apr 29, 2025
by
Harry Mellor
Committed by
GitHub
Apr 29, 2025
Browse files
Transformers backend tweaks (#17365)
Signed-off-by:
Harry Mellor
<
19981378+hmellor@users.noreply.github.com
>
parent
88ad9ec6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
13 deletions
+9
-13
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+9
-13
No files found.
vllm/model_executor/models/transformers.py
View file @
900edfa8
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# limitations under the License.
# limitations under the License.
"""Wrapper around `transformers` models"""
"""Wrapper around `transformers` models"""
import
re
import
re
from
itertools
import
chain
from
typing
import
Iterable
,
Literal
,
Optional
,
Union
from
typing
import
Iterable
,
Literal
,
Optional
,
Union
import
torch
import
torch
...
@@ -166,12 +165,9 @@ class TransformersModel(nn.Module):
...
@@ -166,12 +165,9 @@ class TransformersModel(nn.Module):
# Initialize buffers (e.g. rotary embedding inverse frequency)
# Initialize buffers (e.g. rotary embedding inverse frequency)
self
.
init_buffers
(
self
.
model
)
self
.
init_buffers
(
self
.
model
)
# Initialize parameters
# Initialize
any
parameters
that have not had their modules replaced
self
.
init_parameters
(
self
.
model
)
self
.
init_parameters
(
self
.
model
)
# Move remaining meta tensors to device (should happen last)
self
.
meta_to_empty
(
self
.
model
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
config
.
hidden_size
))
...
@@ -296,6 +292,14 @@ class TransformersModel(nn.Module):
...
@@ -296,6 +292,14 @@ class TransformersModel(nn.Module):
"""
"""
for
name
,
buffer
in
module
.
named_buffers
(
recurse
=
False
):
for
name
,
buffer
in
module
.
named_buffers
(
recurse
=
False
):
if
buffer
.
device
==
torch
.
device
(
"meta"
):
if
buffer
.
device
==
torch
.
device
(
"meta"
):
if
module
==
self
.
model
:
logger
.
warning
(
"To initialize buffers correctly, we instantiate the "
"parent module and and extract the value of the "
"buffer from it. In this case, the parent module is "
"the base model. Instantiating the entire model here "
"risks GPU OOM. Could this buffer be moved to a child "
"module?"
)
new_buffer
=
getattr
(
type
(
module
)(
self
.
config
),
name
)
new_buffer
=
getattr
(
type
(
module
)(
self
.
config
),
name
)
setattr
(
module
,
name
,
new_buffer
)
setattr
(
module
,
name
,
new_buffer
)
for
child
in
module
.
children
():
for
child
in
module
.
children
():
...
@@ -320,14 +324,6 @@ class TransformersModel(nn.Module):
...
@@ -320,14 +324,6 @@ class TransformersModel(nn.Module):
for
child
in
module
.
children
():
for
child
in
module
.
children
():
self
.
init_parameters
(
child
)
self
.
init_parameters
(
child
)
def
meta_to_empty
(
self
,
module
:
nn
.
Module
):
tensors
=
list
(
chain
(
module
.
buffers
(),
module
.
parameters
()))
if
tensors
and
all
(
t
.
device
==
torch
.
device
(
"meta"
)
for
t
in
tensors
):
module
.
to_empty
(
device
=
self
.
device_config
.
device
)
return
# We can stop recursing because to_empty is recursive
for
child
in
module
.
children
():
self
.
meta_to_empty
(
child
)
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
model
.
get_input_embeddings
()
return
self
.
model
.
get_input_embeddings
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment