Commit 89cd9235 (Unverified) in zhaoyu6/sglang
Authored Jan 20, 2025 by Lianmin Zheng; committed by GitHub on Jan 20, 2025
Roll back to use vllm custom allreduce (#3006)
Parent: dc188132

Showing 10 changed files with 18 additions and 65 deletions (+18 -65)
python/sglang/srt/_custom_ops.py (+1 -1)
python/sglang/srt/distributed/__init__.py (+3 -3)
python/sglang/srt/distributed/communication_op.py (+1 -1)
python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py (+0 -1)
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py (+1 -1)
python/sglang/srt/distributed/device_communicators/shm_broadcast.py (+1 -1)
python/sglang/srt/layers/attention/vision.py (+2 -2)
python/sglang/srt/model_executor/cuda_graph_runner.py (+0 -3)
python/sglang/srt/model_executor/model_runner.py (+3 -2)
python/sglang/srt/utils.py (+6 -50)
python/sglang/srt/_custom_ops.py

@@ -12,7 +12,7 @@ import torch.library
 from sglang.srt.utils import is_hpu

 logger = logging.getLogger(__name__)
-use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)

 if not is_hpu():
     if use_vllm_custom_allreduce:
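The only change here is the default of the USE_VLLM_CUSTOM_ALLREDUCE toggle, which now opts back into the vLLM custom allreduce path unless the environment variable overrides it. As a reminder of how this kind of toggle evaluates, a minimal sketch (only the variable name comes from the diff; the if-check and print are illustrative):

import os

# If the variable is unset, the default (a real bool) is returned; if it is
# set, the raw string comes back, and any non-empty string, including "0"
# or "false", is truthy in a plain `if` check.
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
if use_vllm_custom_allreduce:
    print("using the vLLM custom allreduce path")  # illustrative only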
python/sglang/srt/distributed/__init__.py

-from .communication_op import *
-from .parallel_state import *
-from .utils import *
+from sglang.srt.distributed.communication_op import *
+from sglang.srt.distributed.parallel_state import *
+from sglang.srt.distributed.utils import *
python/sglang/srt/distributed/communication_op.py

@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union
 import torch
 import torch.distributed

-from .parallel_state import get_tp_group
+from sglang.srt.distributed.parallel_state import get_tp_group


 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py

@@ -7,7 +7,6 @@ import pickle
 import subprocess
 import sys
 import tempfile
-from functools import lru_cache
 from itertools import product
 from typing import Dict, List, Optional, Sequence
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py

@@ -57,7 +57,7 @@ def find_nccl_library() -> str:
         so_file = "librccl.so.1"
     else:
         raise ValueError("NCCL only supports CUDA and ROCm backends.")
-    logger.info("Found nccl from library %s", so_file)
+    logger.debug("Found nccl from library %s", so_file)
     return so_file
python/sglang/srt/distributed/device_communicators/shm_broadcast.py

@@ -313,7 +313,7 @@ class MessageQueue:
             remote_subscribe_port=remote_subscribe_port,
         )
-        logger.info("vLLM message queue communication handle: %s", self.handle)
+        logger.debug("Message queue communication handle: %s", self.handle)

     def export_handle(self) -> Handle:
         return self.handle
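This change and the pynccl_wrapper.py change above only demote a startup message from info to debug (and drop the "vLLM" prefix), so the information is still logged but hidden by default. A minimal sketch of how to surface these lines again, assuming the loggers are created with logging.getLogger(__name__) as shown in the _custom_ops.py hunk:

import logging

# Illustrative: raise the sglang logger hierarchy to DEBUG so the
# "Found nccl from library ..." and message-queue handle lines appear.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("sglang").setLevel(logging.DEBUG)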
python/sglang/srt/layers/attention/vision.py

@@ -5,9 +5,9 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat

-from vllm.distributed import parallel_state
-from vllm.distributed import utils as dist_utils
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -72,7 +71,6 @@ def patch_model(
     try:
         if enable_compile:
             _to_torch(model, reverse=False, batch_size=batch_size)
-            monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,

@@ -88,7 +86,6 @@ def patch_model(
     finally:
         if enable_compile:
             _to_torch(model, reverse=True, batch_size=batch_size)
-            monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
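Note that the back-up and restore of tp_group.ca_comm around torch.compile stays; only the two monkey_patch_vllm_all_gather calls are dropped. For readers unfamiliar with that back-up/restore idiom, a hypothetical context-manager version of the same pattern (not code from this commit; obj, name, and replacement are placeholders):

from contextlib import contextmanager

@contextmanager
def swapped_attr(obj, name, replacement):
    # Back up the attribute, install the replacement, and always restore
    # the original, even if the body raises (mirrors the try/finally above).
    backup = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, backup)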
python/sglang/srt/model_executor/model_runner.py

@@ -63,8 +63,8 @@ from sglang.srt.utils import (
     init_custom_process_group,
     is_cuda,
     is_hip,
+    monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )

@@ -229,7 +229,8 @@ class ModelRunner:
             backend = "gloo"

         if not self.server_args.enable_p2p_check:
-            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+            monkey_patch_p2p_access_check()

         if self.server_args.dist_init_addr:
             dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
python/sglang/srt/utils.py

@@ -518,66 +518,22 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
     pass


-def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+def monkey_patch_p2p_access_check():
     """
-    Monkey patch the slow p2p access check in vllm.
+    Monkey patch the slow p2p access check.
     NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
     """

-    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt

     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)

     # Suppress the warnings from this delete function when using sglang.bench_one_batch
-    from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
+        CustomAllreduce,
+    )

     setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)


-vllm_all_gather_backup = None
-
-
-def monkey_patch_vllm_all_gather(reverse: bool = False):
-    """Monkey patch all-gather to remove in-place operations."""
-    from torch.distributed import _functional_collectives as funcol
-    from vllm.distributed.parallel_state import GroupCoordinator
-
-    global vllm_all_gather_backup
-    if vllm_all_gather_backup is None:
-        vllm_all_gather_backup = GroupCoordinator.all_gather
-
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        world_size = self.world_size
-        # Bypass the function if we are using only 1 GPU.
-        if world_size == 1:
-            return input_
-        assert (
-            -input_.dim() <= dim < input_.dim()
-        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # Allocate output tensor.
-        output_tensor = torch.empty(
-            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
-        )
-        output_tensor = funcol.all_gather_tensor(
-            input_, gather_dim=0, group=self.device_group
-        ).view((world_size,) + input_size)
-        # Reshape
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(
-            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
-        )
-        return output_tensor
-
-    if reverse:
-        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
-    else:
-        setattr(GroupCoordinator, "all_gather", all_gather)
-
-
 def monkey_patch_vllm_gguf_config():
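The removed monkey_patch_vllm_all_gather replaced vLLM's in-place all-gather with an out-of-place one: gather along a leading world dimension, then splice that dimension into the requested dim via movedim and reshape. A single-process sketch of just that reshape logic (torch.stack stands in for the real collective; world_size, dim, and the shapes are made up for illustration):

import torch

world_size, dim = 4, 1
input_ = torch.randn(2, 3)                     # the per-rank tensor
gathered = torch.stack([input_] * world_size)  # (world_size, 2, 3), fake all-gather

input_size = input_.size()
out = gathered.movedim(0, dim)                 # (2, world_size, 3)
out = out.reshape(
    input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)
assert out.shape == (2, 12)                    # gathered chunks spliced into dim=1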