Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8e1cddcd
Unverified
Commit
8e1cddcd
authored
Oct 17, 2024
by
Woosuk Kwon
Committed by
GitHub
Oct 17, 2024
Browse files
[TPU] Call torch._sync(param) during weight loading (#9437)
parent
5e443b59
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
0 deletions
+22
-0
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+22
-0
No files found.
vllm/model_executor/utils.py
View file @
8e1cddcd
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
...
@@ -28,4 +29,25 @@ def set_weight_attrs(
...
@@ -28,4 +29,25 @@ def set_weight_attrs(
for
key
,
value
in
weight_attrs
.
items
():
for
key
,
value
in
weight_attrs
.
items
():
assert
not
hasattr
(
assert
not
hasattr
(
weight
,
key
),
(
f
"Overwriting existing tensor attribute:
{
key
}
"
)
weight
,
key
),
(
f
"Overwriting existing tensor attribute:
{
key
}
"
)
# NOTE(woosuk): During weight loading, we often do something like:
# narrowed_tensor = param.data.narrow(0, offset, len)
# narrowed_tensor.copy_(real_weight)
# expecting narrowed_tensor and param.data to share the same storage.
# However, on TPUs, narrowed_tensor will lazily propagate to the base
# tensor, which is param.data, leading to the redundant memory usage.
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
if
current_platform
.
is_tpu
()
and
key
==
"weight_loader"
:
value
=
_make_synced_weight_loader
(
value
)
setattr
(
weight
,
key
,
value
)
setattr
(
weight
,
key
,
value
)
def
_make_synced_weight_loader
(
original_weight_loader
):
def
_synced_weight_loader
(
param
,
*
args
,
**
kwargs
):
original_weight_loader
(
param
,
*
args
,
**
kwargs
)
torch
.
_sync
(
param
)
return
_synced_weight_loader
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment