Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a6149aa5
Unverified
Commit
a6149aa5
authored
Sep 19, 2025
by
Chendi.Xue
Committed by
GitHub
Sep 19, 2025
Browse files
[OOT] Support sync_model_loading for OOT (#25126)
Signed-off-by:
Chendi Xue
<
Chendi.Xue@intel.com
>
parent
6c8a3c09
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
17 deletions
+33
-17
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+3
-3
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+3
-14
vllm/platforms/interface.py
vllm/platforms/interface.py
+23
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+4
-0
No files found.
vllm/model_executor/parameter.py
View file @
a6149aa5
...
@@ -12,7 +12,6 @@ from torch.nn import Parameter
...
@@ -12,7 +12,6 @@ from torch.nn import Parameter
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.utils
import
_make_synced_weight_loader
__all__
=
[
__all__
=
[
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
"PerTensorScaleParameter"
,
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
"PerTensorScaleParameter"
,
...
@@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter):
...
@@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter):
# This sometimes causes OOM errors during model loading. To avoid this,
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
# we sync the param tensor after its weight loader is called.
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
current_platform
.
is_tpu
():
if
current_platform
.
use_sync_weight_loader
():
weight_loader
=
_make_synced_weight_loader
(
weight_loader
)
weight_loader
=
current_platform
.
make_synced_weight_loader
(
weight_loader
)
self
.
_weight_loader
=
weight_loader
self
.
_weight_loader
=
weight_loader
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
...
...
vllm/model_executor/utils.py
View file @
a6149aa5
...
@@ -44,23 +44,12 @@ def set_weight_attrs(
...
@@ -44,23 +44,12 @@ def set_weight_attrs(
# TODO(woosuk): Remove this hack once we have a better solution.
# TODO(woosuk): Remove this hack once we have a better solution.
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
current_platform
.
is_tpu
()
and
key
==
"weight_loader"
:
if
current_platform
.
use_sync_weight_loader
(
value
=
_make_synced_weight_loader
(
value
)
)
and
key
==
"weight_loader"
:
value
=
current_platform
.
make_synced_weight_loader
(
value
)
setattr
(
weight
,
key
,
value
)
setattr
(
weight
,
key
,
value
)
def
_make_synced_weight_loader
(
original_weight_loader
):
def
_synced_weight_loader
(
param
,
*
args
,
**
kwargs
):
out
=
original_weight_loader
(
param
,
*
args
,
**
kwargs
)
# torch._sync doesn't support, is not needed for CPU tensors.
if
param
.
device
!=
torch
.
device
(
"cpu"
):
torch
.
_sync
(
param
)
return
out
return
_synced_weight_loader
def
get_packed_modules_mapping
(
model
:
torch
.
nn
.
Module
)
->
dict
[
str
,
list
[
str
]]:
def
get_packed_modules_mapping
(
model
:
torch
.
nn
.
Module
)
->
dict
[
str
,
list
[
str
]]:
parent_map
=
getattr
(
model
,
"packed_modules_mapping"
,
None
)
parent_map
=
getattr
(
model
,
"packed_modules_mapping"
,
None
)
parent_map
=
copy
.
deepcopy
(
parent_map
)
if
parent_map
is
not
None
else
{}
parent_map
=
copy
.
deepcopy
(
parent_map
)
if
parent_map
is
not
None
else
{}
...
...
vllm/platforms/interface.py
View file @
a6149aa5
...
@@ -594,6 +594,29 @@ class Platform:
...
@@ -594,6 +594,29 @@ class Platform:
"""
"""
return
False
return
False
@
classmethod
def
use_sync_weight_loader
(
cls
)
->
bool
:
"""
Returns if the current platform needs to sync weight loader.
"""
return
False
@
classmethod
def
make_synced_weight_loader
(
cls
,
original_weight_loader
):
"""
Wrap the original weight loader to make it synced.
"""
if
not
cls
.
use_sync_weight_loader
():
return
original_weight_loader
def
_synced_weight_loader
(
param
,
*
args
,
**
kwargs
):
out
=
original_weight_loader
(
param
,
*
args
,
**
kwargs
)
if
param
.
device
!=
torch
.
device
(
"cpu"
):
torch
.
_sync
(
param
)
return
out
return
_synced_weight_loader
class
UnspecifiedPlatform
(
Platform
):
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/tpu.py
View file @
a6149aa5
...
@@ -226,6 +226,10 @@ class TpuPlatform(Platform):
...
@@ -226,6 +226,10 @@ class TpuPlatform(Platform):
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
src_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
src_cache
,
True
)
dst_cache
[
dst_block_indices
]
=
src_cache
[
src_block_indices
].
cpu
()
dst_cache
[
dst_block_indices
]
=
src_cache
[
src_block_indices
].
cpu
()
@
classmethod
def
use_sync_weight_loader
(
cls
)
->
bool
:
return
True
try
:
try
:
from
tpu_commons.platforms
import
TpuPlatform
as
TpuCommonsPlatform
from
tpu_commons.platforms
import
TpuPlatform
as
TpuCommonsPlatform
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment