sglang commit 92bb49a7 (unverified)
Authored Mar 27, 2025 by fzyzcjy; committed by GitHub on Mar 27, 2025

Patch PyTorch's bug that cross-process tensor transfer will lead to wrong device (#4565)
parent 6f5cc5eb

Showing 5 changed files with 211 additions and 2 deletions:
python/sglang/srt/entrypoints/verl_engine.py       +2    -0
python/sglang/srt/model_executor/model_runner.py   +4    -2
python/sglang/srt/patch_torch.py                   +71   -0
test/srt/run_suite.py                              +1    -0
test/srt/test_patch_torch.py                       +133  -0
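Context for the change: torch.multiprocessing serializes a CUDA tensor together with its local device index (argument position 6 of the reduce_tensor output, per the patch below), so when the sending and receiving processes order CUDA_VISIBLE_DEVICES differently, the rebuilt tensor lands on the wrong physical GPU. A pure-Python illustration of the mismatch (the device lists are illustrative, not taken from the diff):

sender_visible = [0, 1]    # sender's cuda:1 -> physical GPU 1
receiver_visible = [1, 0]  # receiver's cuda:1 -> physical GPU 0

sent_local_index = 1  # what the unpatched reduce_tensor serializes

physical = sender_visible[sent_local_index]            # physical GPU 1
rebuilt_physical = receiver_visible[sent_local_index]  # physical GPU 0 -- the wrong card
fixed_local_index = receiver_visible.index(physical)   # 0, what a UUID lookup yields
assert rebuilt_physical != physical
assert receiver_visible[fixed_local_index] == physical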
python/sglang/srt/entrypoints/verl_engine.py
@@ -19,6 +19,7 @@ import torch.distributed as dist
 from torch.distributed.tensor import DeviceMesh, DTensor
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
@@ -30,6 +31,7 @@ class VerlEngine:
         nnodes: int = 1,
         **kwargs,
     ):
+        monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
         self._tp_size = device_mesh_cpu.size()
python/sglang/srt/model_executor/model_runner.py
@@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -1082,8 +1083,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
 def _unwrap_tensor(tensor, tp_rank):
     if isinstance(tensor, LocalSerializedTensor):
-        return tensor.get(tp_rank)
-    return tensor
+        monkey_patch_torch_reductions()
+        tensor = tensor.get(tp_rank)
+    return tensor.to(torch.cuda.current_device())


 @dataclass
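The rewritten _unwrap_tensor applies the reductions patch before deserializing and then moves the result onto the process's current CUDA device rather than trusting the device recorded in the pickled metadata. A minimal sketch of that normalize-after-deserialize pattern (FakeSerialized is a hypothetical stand-in for LocalSerializedTensor, which is defined elsewhere in this file):

import torch

class FakeSerialized:
    """Hypothetical stand-in: holds one serialized tensor per tensor-parallel rank."""

    def __init__(self, per_rank):
        self.per_rank = per_rank

    def get(self, tp_rank):
        return self.per_rank[tp_rank]

def unwrap(tensor, tp_rank):
    if isinstance(tensor, FakeSerialized):
        tensor = tensor.get(tp_rank)
    # Normalize onto this process's device so the caller gets a usable tensor
    # even if deserialization placed it on a different one.
    return tensor.to(torch.cuda.current_device())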
python/sglang/srt/patch_torch.py (new file, mode 100644)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable, Union

import torch
from torch.multiprocessing import reductions


def monkey_patch_torch_reductions():
    """Monkey patch to use until Torch https://github.com/pytorch/pytorch/pull/149248 is fixed."""

    if hasattr(reductions, "_reduce_tensor_original"):
        # Already patched; make repeated calls a no-op.
        return

    reductions._reduce_tensor_original = reductions.reduce_tensor
    reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor

    reductions.reduce_tensor = _reduce_tensor_modified
    reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified

    reductions.init_reductions()


# The signature has not been changed for years, and we will not need this once the
# next version is released, so it looks safe to use a constant.
_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6


def _reduce_tensor_modified(*args, **kwargs):
    # Replace the serialized local device index with the GPU's stable UUID.
    output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs)
    output_args = _modify_tuple(
        output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid
    )
    return output_fn, output_args


def _rebuild_cuda_tensor_modified(*args):
    # Translate the UUID back into the receiving process's local device index.
    args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid)
    return reductions._rebuild_cuda_tensor_original(*args)


def _device_to_uuid(device: int) -> str:
    return str(torch.cuda.get_device_properties(device).uuid)


def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
    if isinstance(device_maybe_uuid, int):
        return device_maybe_uuid

    if isinstance(device_maybe_uuid, str):
        for device in range(torch.cuda.device_count()):
            if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid:
                return device
        raise Exception("Invalid device_uuid=" + device_maybe_uuid)

    raise Exception(f"Unknown type: {device_maybe_uuid=}")


def _modify_tuple(t, index: int, modifier: Callable):
    return *t[:index], modifier(t[index]), *t[index + 1 :]
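A quick sanity check of the helpers above (a sketch: it assumes at least one visible CUDA device and reaches into the module's private functions purely for illustration):

import torch
from sglang.srt.patch_torch import (
    _device_from_maybe_uuid,
    _device_to_uuid,
    monkey_patch_torch_reductions,
)

monkey_patch_torch_reductions()
monkey_patch_torch_reductions()  # idempotent: the hasattr guard makes re-runs no-ops

uuid = _device_to_uuid(0)                  # e.g. "GPU-..." for local device 0
assert _device_from_maybe_uuid(uuid) == 0  # same process, so the index round-trips
assert _device_from_maybe_uuid(3) == 3     # plain ints pass through unchanged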
test/srt/run_suite.py
@@ -46,6 +46,7 @@ suites = {
         TestFile("test_openai_server.py", 124),
         TestFile("test_penalty.py", 41),
         TestFile("test_page_size.py", 60),
+        TestFile("test_patch_torch.py", 60),
         TestFile("test_pytorch_sampling_backend.py", 66),
         TestFile("test_radix_attention.py", 167),
         TestFile("test_reasoning_content.py", 89),
test/srt/test_patch_torch.py (new file, mode 100644)
import os
import traceback
import unittest
from typing import Dict, List

import torch
import torch.multiprocessing as mp

from sglang.srt.patch_torch import monkey_patch_torch_reductions


class TestReleaseMemoryOccupation(unittest.TestCase):
    def test_monkey_patch_torch_reductions(self):
        mp.set_start_method("spawn", force=True)

        for enable_patch in [False, True]:
            for params in [
                # Same visible devices
                dict(
                    sender_info=dict(
                        visible_devices=[0, 1],
                        tensor_device=1,
                    ),
                    receiver_info=dict(
                        visible_devices=[0, 1],
                        tensor_device=1,
                    ),
                ),
                # Different visible devices
                dict(
                    sender_info=dict(
                        visible_devices=[0, 1],
                        tensor_device=1,
                    ),
                    receiver_info=dict(
                        visible_devices=[1, 0],
                        # If the patch is enabled, this is fixed: cuda:1 becomes cuda:0
                        tensor_device=0 if enable_patch else 1,
                    ),
                ),
            ]:
                with self.subTest(f"{enable_patch=} {params=}"):
                    self._test_monkey_patch_torch_reductions_core(
                        enable_patch=enable_patch, **params
                    )

    def _test_monkey_patch_torch_reductions_core(
        self,
        sender_info: Dict,
        receiver_info: Dict,
        enable_patch: bool,
    ):
        print(
            f'test_monkey_patch_torch_reductions_core {os.environ.get("CUDA_VISIBLE_DEVICES")=}'
        )
        cuda_visible_devices_list: List[int] = [
            int(x)
            for x in os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(",")
        ]

        processes = []
        output_reader, output_writer = mp.Pipe(duplex=False)
        queue = mp.Queue()
        for role, info in [
            ("sender", sender_info),
            ("receiver", receiver_info),
        ]:
            # Remap each subprocess's visible devices before it is spawned.
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(cuda_visible_devices_list[device])
                for device in info["visible_devices"]
            )
            p = mp.Process(
                target=_run_subprocess,
                kwargs=dict(
                    role=role,
                    queue=queue,
                    output_writer=output_writer,
                    tensor_device=info["tensor_device"],
                    enable_patch=enable_patch,
                ),
            )
            p.start()
            processes.append(p)

        for _ in range(len(processes)):
            self.assertTrue(
                output_reader.recv(), "Subprocess has error, please see logs above."
            )

        for p in processes:
            p.join()


def _run_subprocess(
    role: str,
    queue: mp.Queue,
    output_writer,
    tensor_device: int,
    enable_patch: bool,
):
    print(
        f'subprocess[{role}] start {os.environ.get("CUDA_VISIBLE_DEVICES")=}',
        flush=True,
    )

    if enable_patch:
        print(f"subprocess[{role}] execute monkey_patch_torch_reductions", flush=True)
        monkey_patch_torch_reductions()

    try:
        if role == "sender":
            tensor = torch.tensor([1.0, 2.0], device=f"cuda:{tensor_device}")
            print(f"sender queue.put {tensor=} {tensor.device=}")
            queue.put(tensor)
            # Keep the sender alive until the receiver has checked the tensor.
            assert queue.get() == "done"
        elif role == "receiver":
            tensor = queue.get()
            print(f"receiver queue.get {tensor=} {tensor.device=}")
            assert str(tensor.device) == f"cuda:{tensor_device}"
            queue.put("done")
        else:
            raise NotImplementedError

        execution_ok = True
    except Exception as e:
        print(f"subprocess[{role}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()


if __name__ == "__main__":
    unittest.main()
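To make the "Different visible devices" case concrete, this is the index arithmetic the test encodes (pure Python; the UUIDs are made up):

physical = ["GPU-aaaa", "GPU-bbbb"]  # the host's two physical GPUs

# The sender runs with CUDA_VISIBLE_DEVICES=0,1 and puts the tensor on cuda:1,
# i.e. physical GPU 1; the patched reduce_tensor ships that GPU's UUID.
sent_uuid = physical[1]

# The receiver runs with CUDA_VISIBLE_DEVICES=1,0, so its local ordering differs.
receiver_order = [physical[i] for i in [1, 0]]
assert receiver_order.index(sent_uuid) == 0  # patched: the tensor arrives on cuda:0

# Unpatched, the raw index 1 is reused, i.e. the receiver's cuda:1 == physical
# GPU 0 -- hence the test's tensor_device = 0 if enable_patch else 1.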