change / sglang / Commits / 6b847a9a
"src/vscode:/vscode.git/clone" did not exist on "8d99d30a82d8a623b65b19fb0ba0ef473763091e"
Commit 6b847a9a (unverified), authored Aug 10, 2025 by JiLi; committed by GitHub, Aug 10, 2025:
Optimize: Cache CUDA device to reduce redundant calls during tensor l… (#8996)
Parent: 473400e4
Showing 1 changed file with 7 additions and 4 deletions.

python/sglang/srt/model_executor/model_runner.py (+7 / -4)
```diff
@@ -895,8 +895,12 @@ class ModelRunner:
         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
         load_format: Optional[str] = None,
     ):
+        monkey_patch_torch_reductions()
+        # We need to get device after patch otherwise the device would be wrong
+        infered_device = torch.cuda.current_device()
         named_tensors = [
-            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
             for name, tensor in named_tensors
         ]
         if load_format == "direct":
```
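The first hunk hoists the device lookup out of the per-tensor list comprehension: `torch.cuda.current_device()` is queried once, after `monkey_patch_torch_reductions()` has been applied, and the cached index is passed to every `_unwrap_tensor` call. A minimal standalone sketch of the same pattern, assuming a CUDA device is available (the helper name here is hypothetical, not the sglang code itself):

```python
import torch
from typing import List, Tuple


def unwrap_named_tensors(named_tensors: List[Tuple[str, torch.Tensor]]):
    """Hypothetical helper mirroring the commit's pattern: resolve the target
    CUDA device once and reuse it, instead of calling
    torch.cuda.current_device() for every tensor in the loop."""
    device = torch.cuda.current_device()  # single lookup, cached for the loop
    return [(name, tensor.to(device)) for name, tensor in named_tensors]
```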
```diff
@@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
         default_weight_loader(params_dict[name], tensor)


-def _unwrap_tensor(tensor, tp_rank):
+def _unwrap_tensor(tensor, tp_rank, device):
     if isinstance(tensor, LocalSerializedTensor):
-        monkey_patch_torch_reductions()
         tensor = tensor.get(tp_rank)
-    return tensor.to(torch.cuda.current_device())
+    return tensor.to(device)


@dataclass
```
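The second hunk changes `_unwrap_tensor` to take the device explicitly instead of querying it on every call, and drops the per-call `monkey_patch_torch_reductions()` now that the caller patches once up front. A rough micro-benchmark sketch of the redundancy being removed (illustrative only, assuming a CUDA-capable machine; timings will vary):

```python
import time
import torch

N = 10_000  # arbitrary count standing in for the number of tensors loaded

# Redundant pattern: query the current CUDA device once per tensor.
start = time.perf_counter()
devices = [torch.cuda.current_device() for _ in range(N)]
per_call = time.perf_counter() - start

# Cached pattern (what the commit switches to): query once, reuse everywhere.
start = time.perf_counter()
device = torch.cuda.current_device()
reused = [device for _ in range(N)]
cached = time.perf_counter() - start

print(f"per-call lookups: {per_call:.4f}s, cached lookup: {cached:.4f}s")
```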