sglang · Commit 89cd9235 (unverified)
Authored Jan 20, 2025 by Lianmin Zheng; committed by GitHub on Jan 20, 2025
Parent: dc188132

Roll back to use vllm custom allreduce (#3006)

Showing 10 changed files with 18 additions and 65 deletions (+18, -65)
python/sglang/srt/_custom_ops.py (+1, -1)
python/sglang/srt/distributed/__init__.py (+3, -3)
python/sglang/srt/distributed/communication_op.py (+1, -1)
python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py (+0, -1)
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py (+1, -1)
python/sglang/srt/distributed/device_communicators/shm_broadcast.py (+1, -1)
python/sglang/srt/layers/attention/vision.py (+2, -2)
python/sglang/srt/model_executor/cuda_graph_runner.py (+0, -3)
python/sglang/srt/model_executor/model_runner.py (+3, -2)
python/sglang/srt/utils.py (+6, -50)
python/sglang/srt/_custom_ops.py

@@ -12,7 +12,7 @@ import torch.library
 from sglang.srt.utils import is_hpu
 logger = logging.getLogger(__name__)
-use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
 if not is_hpu():
     if use_vllm_custom_allreduce:
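A side note on how this kind of flag resolves at runtime (a minimal sketch, not part of the diff): os.environ.get returns the raw string whenever the variable is exported, and falls back to the Python default (now True) only when it is unset, so any non-empty value is truthy unless it is parsed explicitly.

# Illustrative only; assumes the flag may be exported before launch,
# e.g. USE_VLLM_CUSTOM_ALLREDUCE=0
import os

raw = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)

if isinstance(raw, str):
    # Parse common "false" spellings explicitly; a bare "0" would otherwise
    # be truthy because it is a non-empty string.
    use_vllm_custom_allreduce = raw.strip().lower() not in ("0", "false", "off", "")
else:
    use_vllm_custom_allreduce = bool(raw)

print(use_vllm_custom_allreduce)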
python/sglang/srt/distributed/__init__.py

-from .communication_op import *
-from .parallel_state import *
-from .utils import *
+from sglang.srt.distributed.communication_op import *
+from sglang.srt.distributed.parallel_state import *
+from sglang.srt.distributed.utils import *
python/sglang/srt/distributed/communication_op.py

@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union
 import torch
 import torch.distributed
-from .parallel_state import get_tp_group
+from sglang.srt.distributed.parallel_state import get_tp_group
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py

@@ -7,7 +7,6 @@ import pickle
 import subprocess
 import sys
 import tempfile
-from functools import lru_cache
 from itertools import product
 from typing import Dict, List, Optional, Sequence
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py

@@ -57,7 +57,7 @@ def find_nccl_library() -> str:
         so_file = "librccl.so.1"
     else:
         raise ValueError("NCCL only supports CUDA and ROCm backends.")
-    logger.info("Found nccl from library %s", so_file)
+    logger.debug("Found nccl from library %s", so_file)
     return so_file
python/sglang/srt/distributed/device_communicators/shm_broadcast.py

@@ -313,7 +313,7 @@ class MessageQueue:
             remote_subscribe_port=remote_subscribe_port,
         )
-        logger.info("vLLM message queue communication handle: %s", self.handle)
+        logger.debug("Message queue communication handle: %s", self.handle)

     def export_handle(self) -> Handle:
         return self.handle
python/sglang/srt/layers/attention/vision.py

@@ -5,9 +5,9 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
-from vllm.distributed import parallel_state
-from vllm.distributed import utils as dist_utils
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -72,7 +71,6 @@ def patch_model(
     try:
         if enable_compile:
             _to_torch(model, reverse=False, batch_size=batch_size)
-            monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -88,7 +86,6 @@ def patch_model(
     finally:
         if enable_compile:
             _to_torch(model, reverse=True, batch_size=batch_size)
-            monkey_patch_vllm_all_gather(reverse=True)
            tp_group.ca_comm = backup_ca_comm
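What survives in patch_model is a backup/patch/restore pattern around compilation: tp_group.ca_comm is stashed before the compiled region and put back in the finally block. A minimal, self-contained sketch of that pattern follows; the stand-in objects and the temporary swap are hypothetical, not taken from the diff.

from types import SimpleNamespace

def run_with_patched_comm(tp_group, run, temporary_comm=None):
    # Stash the current custom-allreduce communicator, as the diff does with
    # backup_ca_comm = tp_group.ca_comm.
    backup_ca_comm = tp_group.ca_comm
    try:
        # Hypothetical: swap in an alternative communicator (or None) while
        # the compiled/captured region executes.
        tp_group.ca_comm = temporary_comm
        return run()
    finally:
        # Always restore, even if compilation or capture raises.
        tp_group.ca_comm = backup_ca_comm

tp_group = SimpleNamespace(ca_comm="custom-allreduce")
result = run_with_patched_comm(tp_group, lambda: "compiled forward")
assert tp_group.ca_comm == "custom-allreduce"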
python/sglang/srt/model_executor/model_runner.py

@@ -63,8 +63,8 @@ from sglang.srt.utils import (
     init_custom_process_group,
     is_cuda,
     is_hip,
+    monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
@@ -229,7 +229,8 @@ class ModelRunner:
             backend = "gloo"
         if not self.server_args.enable_p2p_check:
-            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+            monkey_patch_p2p_access_check()
         if self.server_args.dist_init_addr:
             dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
python/sglang/srt/utils.py

@@ -518,66 +518,22 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
         pass


-def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+def monkey_patch_p2p_access_check():
     """
-    Monkey patch the slow p2p access check in vllm.
+    Monkey patch the slow p2p access check.
     NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
     """

-    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt

     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)

     # Suppress the warnings from this delete function when using sglang.bench_one_batch
-    from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
+        CustomAllreduce,
+    )

     setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)


-vllm_all_gather_backup = None
-
-
-def monkey_patch_vllm_all_gather(reverse: bool = False):
-    """Monkey patch all-gather to remove in-place operations."""
-    from torch.distributed import _functional_collectives as funcol
-    from vllm.distributed.parallel_state import GroupCoordinator
-
-    global vllm_all_gather_backup
-    if vllm_all_gather_backup is None:
-        vllm_all_gather_backup = GroupCoordinator.all_gather
-
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        world_size = self.world_size
-        # Bypass the function if we are using only 1 GPU.
-        if world_size == 1:
-            return input_
-        assert (
-            -input_.dim() <= dim < input_.dim()
-        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # Allocate output tensor.
-        output_tensor = torch.empty(
-            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
-        )
-        output_tensor = funcol.all_gather_tensor(
-            input_, gather_dim=0, group=self.device_group
-        ).view((world_size,) + input_size)
-        # Reshape
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(
-            input_size[:dim]
-            + (world_size * input_size[dim],)
-            + input_size[dim + 1 :]
-        )
-        return output_tensor
-
-    if reverse:
-        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
-    else:
-        setattr(GroupCoordinator, "all_gather", all_gather)
-
-
 def monkey_patch_vllm_gguf_config():
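The core technique in monkey_patch_p2p_access_check is replacing a module attribute with a trivial callable via setattr. A self-contained sketch under stand-in names (fake_comm_utils and slow_gpu_p2p_access_check are invented for illustration; the real target is sglang.srt.distributed.device_communicators.custom_all_reduce_utils.gpu_p2p_access_check):

import types

# Stand-in for the module that owns the expensive check.
fake_comm_utils = types.ModuleType("fake_comm_utils")

def slow_gpu_p2p_access_check(src: int, tgt: int) -> bool:
    # Imagine a subprocess-based probe of CUDA peer-to-peer access here.
    return False

fake_comm_utils.gpu_p2p_access_check = slow_gpu_p2p_access_check

# The patch: any caller that looks the function up through the module now
# gets a stub that unconditionally reports P2P access as allowed.
setattr(fake_comm_utils, "gpu_p2p_access_check", lambda *args, **kwargs: True)

assert fake_comm_utils.gpu_p2p_access_check(0, 1) is True

As the docstring warns, assuming P2P access is always allowed can be wrong on some setups, which is why model_runner.py only applies the patch when enable_p2p_check is off.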