Commit 8e3797be (unverified) — sglang

support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)

Authored Jun 05, 2025 by zyksir; committed by GitHub on Jun 04, 2025.
Parent: 4474eaf5

Showing 20 changed files with 2177 additions and 12 deletions (+2177, -12).
Changed files:

- benchmark/kernels/all_reduce/benchmark_mscclpp.py (+224, -0)
- benchmark/lora/launch_server.py (+7, -0)
- docs/backend/server_arguments.md (+1, -0)
- python/sglang/srt/_custom_ops.py (+34, -0)
- python/sglang/srt/distributed/device_communicators/pymscclpp.py (+315, -0)
- python/sglang/srt/distributed/parallel_state.py (+48, -4)
- python/sglang/srt/layers/dp_attention.py (+6, -5)
- python/sglang/srt/model_executor/model_runner.py (+2, -0)
- python/sglang/srt/server_args.py (+6, -0)
- sgl-kernel/CMakeLists.txt (+31, -1)
- sgl-kernel/Makefile (+2, -2)
- sgl-kernel/build.sh (+3, -0)
- sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu (+140, -0)
- sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh (+779, -0)
- sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu (+153, -0)
- sgl-kernel/csrc/common_extension.cc (+9, -0)
- sgl-kernel/include/sgl_kernel_ops.h (+12, -0)
- sgl-kernel/python/sgl_kernel/allreduce.py (+54, -0)
- sgl-kernel/tests/test_mscclpp.py (+146, -0)
- test/srt/test_mscclpp.py (+205, -0)
benchmark/kernels/all_reduce/benchmark_mscclpp.py (new file, +224)

```python
"""For now, MSCCLPP all-reduce is only supported in the TP8 and TP16 cases.

export WORLD_SIZE=1
export RANK=0
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345
torchrun --nproc_per_node gpu \
    --nnodes $WORLD_SIZE \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py
"""

import os
from contextlib import nullcontext
from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
    set_mscclpp_all_reduce,
)


def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
    dist.all_reduce(torch_input, group=group)
    return torch_input


def msccl_allreduce(
    msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator
) -> torch.Tensor:
    return msccl_comm.all_reduce(msccl_input)


def pynccl_allreduce(
    msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
) -> torch.Tensor:
    pynccl_comm.all_reduce(msccl_input)
    return msccl_input


def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
    graph_input = inp_randn.clone()
    with graph_capture() as graph_capture_context:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
            for _ in range(graph_loop):
                graph_out = func(graph_input)

    graph.replay()
    func_output = graph_out.clone()
    for _ in range(warmup_loop):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    latencies: List[float] = []
    for _ in range(test_loop):
        torch.cuda.synchronize()
        dist.barrier()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))

    func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
    graph.reset()
    return func_output, func_cost_us


def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
    eager_input = inp_randn.clone()
    eager_output = func(eager_input)
    func_output = eager_output.clone()
    for _ in range(warmup_loop):
        func(eager_input)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(test_loop):
        func(eager_input)
    end_event.record()
    torch.cuda.synchronize()
    func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
    return func_output, func_cost_us


def get_torch_prof_ctx(do_prof: bool):
    ctx = (
        torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=True,
            with_stack=True,
        )
        if do_prof
        else nullcontext()
    )
    return ctx


def human_readable_size(size, decimal_places=1):
    for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
        if size < 1024.0 or unit == "PiB":
            break
        size /= 1024.0
    return f"{size:.{decimal_places}f} {unit}"


try:
    from tabulate import tabulate
except ImportError:
    print("tabulate not installed, skipping table printing")
    tabulate = None


def print_markdown_table(data):
    if tabulate is not None:
        print(tabulate(data, headers="keys", tablefmt="github"))
        return
    headers = data[0].keys()
    header_row = "| " + " | ".join(headers) + " |"
    separator = "| " + " | ".join(["---"] * len(headers)) + " |"
    rows = []
    for item in data:
        row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
        rows.append(row)
    markdown_table = "\n".join([header_row, separator] + rows)
    print(markdown_table)


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    world, world_size = dist.group.WORLD, dist.get_world_size()
    rank = dist.get_rank()
    torch.cuda.set_device(rank % 8)
    device = torch.cuda.current_device()
    set_mscclpp_all_reduce(True)
    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        local_rank=rank % 8,
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)
    group = get_tensor_model_parallel_group().device_group
    cpu_group = get_tensor_model_parallel_group().cpu_group
    pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
    pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm
    dist.barrier()

    profile = False
    dtype = torch.bfloat16
    ctx = get_torch_prof_ctx(profile)
    result = []
    with ctx:
        for i in range(10, 20):
            sz = 2**i
            if sz * dtype.itemsize > 2**20:
                break
            inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
            memory = torch.empty_like(inp_randn)
            memory_out = torch.empty_like(memory)
            torch_eager_output, torch_eager_time = _bench_eager_time(
                lambda inp: torch_allreduce(inp, group), inp_randn
            )
            msccl_eager_output, msccl_eager_time = _bench_eager_time(
                lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
            )
            msccl_graph_output, msccl_graph_time = _bench_graph_time(
                lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
            )
            # since pynccl all-reduce is in-place, the returned result is not
            # correct if graph_loop > 1
            _, pynccl_graph_time = _bench_graph_time(
                lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
            )
            torch.testing.assert_close(torch_eager_output, msccl_graph_output)
            torch.testing.assert_close(torch_eager_output, msccl_eager_output)
            result.append(
                {
                    "msg_size": human_readable_size(inp_randn.nbytes),
                    "torch eager time": torch_eager_time,
                    "msccl eager time": msccl_eager_time,
                    "msccl graph time": msccl_graph_time,
                    "pynccl graph time": pynccl_graph_time,
                }
            )
            if rank == 0:
                print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
    if rank == 0:
        print_markdown_table(result)
    if profile:
        prof_dir = f"prof/msccl"
        os.makedirs(prof_dir, exist_ok=True)
        ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
```
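A note on the timing arithmetic in `_bench_graph_time`: each CUDA-graph replay contains `graph_loop` all-reduce calls, and `torch.cuda.Event.elapsed_time` returns milliseconds, so the average replay time is divided by `graph_loop` and scaled to microseconds. A tiny illustration with hypothetical numbers (not measured data):

```python
# Sketch of the per-call cost computed by _bench_graph_time.
latencies_ms = [0.42, 0.40, 0.41]  # hypothetical per-replay timings in ms
graph_loop = 10                    # all-reduce calls captured per graph replay
per_call_us = sum(latencies_ms) / len(latencies_ms) / graph_loop * 1000
print(f"{per_call_us:.1f} us per all-reduce")  # -> 41.0 us for these numbers
```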
benchmark/lora/launch_server.py (+7)

```python
@@ -26,6 +26,8 @@ def launch_server(args):
    cmd += f"--tp-size {args.tp_size} "
    if args.disable_custom_all_reduce:
        cmd += "--disable-custom-all-reduce"
    if args.enable_mscclpp:
        cmd += "--enable-mscclpp"
    print(cmd)
    os.system(cmd)

@@ -63,6 +65,11 @@ if __name__ == "__main__":
        action="store_true",
        help="Disable custom all reduce when device does not support p2p communication",
    )
    parser.add_argument(
        "--enable-mscclpp",
        action="store_true",
        help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
    )
    args = parser.parse_args()
    launch_server(args)
```
docs/backend/server_arguments.md (+1)

```markdown
@@ -201,6 +201,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `disable_cuda_graph_padding` | Disable CUDA Graph when padding is needed; otherwise, still use CUDA Graph. | `False` |
| `disable_outlines_disk_cache` | Disable disk cache for outlines grammar backend. | `False` |
| `disable_custom_all_reduce` | Disable usage of custom all-reduce kernel. | `False` |
| `enable_mscclpp` | Enable usage of mscclpp kernel for small message all-reduce. | `False` |
| `disable_overlap_schedule` | Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). | `False` |
| `enable_nan_detection` | Enable warning if the logits contain `NaN`. | `False` |
| `enable_p2p_check` | Turns off the default of always allowing P2P checks when accessing GPU. | `False` |
```
python/sglang/srt/_custom_ops.py (+34)

```python
@@ -113,3 +113,37 @@ else:
    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)

    def mscclpp_generate_unique_id() -> bytes:
        return sgl_kernel.allreduce.mscclpp_generate_unique_id()

    def mscclpp_init_context(
        unique_id: bytes,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        return sgl_kernel.allreduce.mscclpp_init_context(
            unique_id,
            rank,
            world_size,
            scratch,
            put_buffer,
            nranks_per_node,
            rank_to_node,
            rank_to_ib,
            context_selection,
        )

    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        return sgl_kernel.allreduce.mscclpp_allreduce(
            context, inp, out, nthreads, nblocks
        )
```
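For orientation, the call sequence these wrappers expose is the same one exercised by `sgl-kernel/tests/test_mscclpp.py` later in this commit. A minimal hedged sketch (not part of the diff; the buffer sizes and the single-node context selection are assumptions for illustration, and `torch.distributed` must already be initialized with one process per GPU):

```python
# Sketch only: mirrors the flow used by the unit test added in this commit.
import torch
import torch.distributed as dist

from sglang.srt import _custom_ops as ops


def init_mscclpp(rank: int, world_size: int, max_bytes: int = 2**20):
    # Rank 0 creates the bootstrap id; all ranks receive it via an object broadcast.
    unique_id = [ops.mscclpp_generate_unique_id()] if rank == 0 else [None]
    dist.broadcast_object_list(unique_id, src=0)

    # The scratch/put buffers must stay alive for as long as the context is used,
    # because the C++ side keeps raw pointers into them.
    scratch = torch.empty(max_bytes * 8, dtype=torch.uint8, device="cuda")
    put_buffer = torch.empty(max_bytes, dtype=torch.uint8, device="cuda")

    rank_to_node = [r // 8 for r in range(world_size)]
    rank_to_ib = [r % 8 for r in range(world_size)]
    context = ops.mscclpp_init_context(
        unique_id[0], rank, world_size, scratch, put_buffer,
        torch.cuda.device_count(), rank_to_node, rank_to_ib,
        1,  # 1 = MSCCL1SHOT1NODELL (single 8-GPU node), 2 = the two-node variant
    )
    return context, scratch, put_buffer
```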
python/sglang/srt/distributed/device_communicators/pymscclpp.py (new file, +315)

```python
import bisect
import logging
import math
import os
from contextlib import contextmanager
from enum import IntEnum
from typing import Any, Callable, List, Optional, TypeVar, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

from sglang.srt import _custom_ops as ops
from sglang.srt.utils import is_cuda, is_hip

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

mscclpp_is_available = False
if _is_hip:
    # TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
    mscclpp_is_available = False
if _is_cuda:
    try:
        import sgl_kernel

        mscclpp_is_available = True
    except:
        mscclpp_is_available = False


class MscclContextSelection(IntEnum):
    MSCCL1SHOT1NODELL = 1
    MSCCL1SHOT2NODELL = 2


def mscclpp_is_weak_contiguous(inp: torch.Tensor):
    return inp.is_contiguous() or (
        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
        == inp.numel() * inp.element_size()
    )


def mscclpp_convert_to_bytes(size_str):
    """
    Converts a human-readable size string (e.g., "1MB", "2.5kb", "3 GB")
    into the equivalent number of bytes using binary units.

    Args:
        size_str (str): A string representing size with unit (KB, MB, GB).

    Returns:
        int: Number of bytes.
    """
    size_str = size_str.strip().lower()

    if not size_str:
        raise ValueError("Empty input string")

    # Extract numeric part and unit
    for i in range(len(size_str)):
        if not size_str[i].isdigit() and size_str[i] != ".":
            break
    num_str = size_str[:i]
    unit = size_str[i:].strip()

    try:
        num = float(num_str)
    except ValueError:
        raise ValueError(f"Invalid numeric value in '{size_str}'")

    # Conversion factors
    if unit == "b":
        return int(num)
    elif unit == "kb":
        return int(num * 1024)
    elif unit == "mb":
        return int(num * 1024 * 1024)
    elif unit == "gb":
        return int(num * 1024 * 1024 * 1024)
    else:
        raise ValueError(f"Unsupported unit: {unit}, support B, KB, MB, GB only")


def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
    # warmup
    for _ in range(warmup_niter):
        func()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    dist.barrier()
    start_event.record()
    for _ in range(test_niter):
        func()
    end_event.record()
    end_event.synchronize()
    func_cost_us = start_event.elapsed_time(end_event) / test_niter * 1000
    return func_cost_us


class PyMscclppCommunicator:
    _SUPPORTED_WORLD_SIZES = [8, 16]
    _MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
    _SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]

    # max_bytes: max supported mscclpp allreduce size
    # on A100, mscclpp is faster than nccl only when the message size is smaller than 1MB
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_bytes=_MAX_BYTES,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bound to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True

        if not mscclpp_is_available:
            # disable because of missing mscclpp library
            # e.g. in a non-cuda environment
            return

        self.group = group

        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "CustomAllreduce should be attached to a non-NCCL group."

        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize mscclpp for single GPU case.
            return

        if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "PyMscclpp is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_mscclpp=True explicitly.",
                world_size,
                str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),
            )
            return

        self.ranks = torch.distributed.get_process_group_ranks(group)
        self.nranks_per_node = torch.cuda.device_count()
        # for now mscclpp with stride in the communicator is not tested
        if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):
            logger.warning(
                "PyMscclpp is disabled due to an unsupported group %s."
                "Please ensure all ranks in the group are consecutive."
                "To silence this warning, specify disable_mscclpp=True explicitly.",
                str(self.ranks),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        self.max_bytes = max_bytes
        self.rank = rank
        self.world_size = world_size

        if dist.get_rank(group) == 0:
            unique_id = [ops.mscclpp_generate_unique_id()]
        else:
            unique_id = [None]
        dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)
        self.unique_id = unique_id[0]
        self.rank_to_node, self.rank_to_ib = list(range(world_size)), list(
            range(world_size)
        )
        for r in range(world_size):
            self.rank_to_node[r] = r // 8
            self.rank_to_ib[r] = self.rank % 8

        self._context = None
        self.context_selection = None
        self.msg_size_for_finetune = [
            2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)
        ]
        self.msg_size2best_config = {}
        if world_size == 8:
            self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL
        elif world_size == 16:
            self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL
        if not _is_hip:
            self.scratch = torch.empty(
                self.max_bytes * 8,
                dtype=torch.uint8,
                device=self.device,
            )
            self.put_buffer = torch.empty(
                self.max_bytes * 8 // self.nranks_per_node,
                dtype=torch.uint8,
                device=self.device,
            )
            self._context = ops.mscclpp_init_context(
                self.unique_id,
                self.rank,
                self.world_size,
                self.scratch,
                self.put_buffer,
                self.nranks_per_node,
                self.rank_to_node,
                self.rank_to_ib,
                int(self.context_selection),
            )
        else:
            raise NotImplementedError("HIP Mscclpp is not supported yet.")

        self.msg_size2best_config = {}
        self.pre_tune_config()
        if dist.get_rank(group) == 0:
            msg_size2best_config = [self.msg_size2best_config]
        else:
            msg_size2best_config = [None]
        dist.broadcast_object_list(
            msg_size2best_config, src=self.ranks[0], group=self.group
        )
        self.msg_size2best_config = msg_size2best_config[0]

        # PyMscclpp is enabled only in cuda graph
        self.disabled = True

    def pre_tune_config(self, dtype=torch.bfloat16) -> bool:
        logger.debug(f"start to pre-tune configs for rank {self.rank}")
        nthreads_to_try = [256, 512, 1024]
        nblocks_to_try = [21, 42, 84]
        inp_randn = torch.ones(
            self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device="cuda"
        )
        oup_randn = torch.empty_like(inp_randn)
        for msg_size in self.msg_size_for_finetune:
            mock_inp, mock_outp = (
                inp_randn[: msg_size // dtype.itemsize],
                oup_randn[: msg_size // dtype.itemsize],
            )
            best_config, best_time = None, None
            for nthreads in nthreads_to_try:
                for nblocks in nblocks_to_try:
                    cur_cost = mscclpp_bench_time(
                        lambda: ops.mscclpp_allreduce(
                            self._context, mock_inp, mock_outp, nthreads, nblocks
                        )
                    )
                    if best_time is None or cur_cost < best_time:
                        best_config = (nthreads, nblocks)
                        best_time = cur_cost
            self.msg_size2best_config[msg_size] = best_config
            if self.rank == 0:
                logger.debug(
                    f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us"
                )

    def should_mscclpp_allreduce(
        self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM
    ) -> bool:
        if self.disabled or self._context is None:
            return False
        if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE:
            return False
        if not mscclpp_is_weak_contiguous(inp):
            return False
        # only support sum op
        if op != ReduceOp.SUM:
            return False
        if inp.numel() * inp.element_size() > self.max_bytes:
            return False
        return True

    def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                self.graph_input_set.add((tensor.dtype, tensor.numel()))
        msg_size = tensor.numel() * tensor.itemsize
        index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)
        msg_size_finetune = self.msg_size_for_finetune[index]
        nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]
        result = torch.empty_like(tensor)
        ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)
        return result

    @contextmanager
    def change_state(
        self,
        enable: Optional[bool] = None,
    ):
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        old_disable = self.disabled
        self.disabled = not enable

        yield

        self.disabled = old_disable
```
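As a usage sketch (not part of this diff), the communicator is meant to be driven the way `GroupCoordinator` drives it in `parallel_state.py` below: it is constructed disabled, enabled via `change_state` only around CUDA-graph capture, and consulted through `should_mscclpp_allreduce` before the out-of-place `all_reduce`. The helper below is a hedged illustration under those assumptions:

```python
# Hedged sketch of the intended call pattern (names match this diff).
import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator


def mscclpp_or_nccl_all_reduce(
    comm: PyMscclppCommunicator, x: torch.Tensor, group: dist.ProcessGroup
) -> torch.Tensor:
    # In the real code path the communicator is enabled only around CUDA-graph capture.
    with comm.change_state(enable=True):
        if comm.should_mscclpp_allreduce(x):  # dtype/contiguity/size/ReduceOp.SUM checks
            return comm.all_reduce(x)         # out-of-place: returns a fresh tensor
    dist.all_reduce(x, group=group)           # otherwise fall back to the NCCL path
    return x
```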
python/sglang/srt/distributed/parallel_state.py (+48, -4)

```python
@@ -190,6 +190,7 @@ class GroupCoordinator:
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_pymscclpp: bool  # a hint of whether to use PyMsccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    use_message_queue_broadcaster: (
        bool  # a hint of whether to use message queue broadcaster

@@ -205,6 +206,7 @@ class GroupCoordinator:
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_pymscclpp: bool,
        use_custom_allreduce: bool,
        use_hpu_communicator: bool,
        use_xpu_communicator: bool,

@@ -244,6 +246,7 @@ class GroupCoordinator:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_pymscclpp = use_pymscclpp
        self.use_custom_allreduce = use_custom_allreduce
        self.use_hpu_communicator = use_hpu_communicator
        self.use_xpu_communicator = use_xpu_communicator

@@ -265,6 +268,17 @@ class GroupCoordinator:
                device=self.device,
            )

        from sglang.srt.distributed.device_communicators.pymscclpp import (
            PyMscclppCommunicator,
        )

        self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
        if use_pymscclpp and self.world_size > 1:
            self.pymscclpp_comm = PyMscclppCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.

@@ -373,11 +387,15 @@ class GroupCoordinator:
        # --------------------------------------------
        # custom allreduce  | enabled | enabled |
        # PyNccl            | disabled| enabled |
        # PyMscclpp         | disabled| enabled |
        # torch.distributed | enabled | disabled|
        #
        # Note that custom allreduce will have a runtime check, if the
        # tensor size is too large, it will fallback to the next
        # available option.
        # Note that PyMscclpp needs to register the tensor ahead of time,
        # which would introduce a large overhead in the eager case,
        # therefore it is only supported in the graph case.
        # In summary: When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using
        # CUDA graph, we use either custom all-reduce kernel or

@@ -392,7 +410,14 @@ class GroupCoordinator:
            maybe_pynccl_context = pynccl_comm.change_state(
                enable=True, stream=torch.cuda.current_stream()
            )
        pymscclpp_comm = self.pymscclpp_comm
        maybe_pymscclpp_context: Any
        if not pymscclpp_comm:
            maybe_pymscclpp_context = nullcontext()
        else:
            maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)
        with maybe_pynccl_context, maybe_pymscclpp_context:
            yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:

@@ -437,6 +462,10 @@ class GroupCoordinator:
            self.ca_comm is not None
            and not self.ca_comm.disabled
            and self.ca_comm.should_custom_ar(input_)
        ) or (
            self.pymscclpp_comm is not None
            and not self.pymscclpp_comm.disabled
            and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
        ):
            return torch.ops.sglang.outplace_all_reduce(
                input_, group_name=self.unique_name

@@ -447,9 +476,13 @@ class GroupCoordinator:
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
        ca_comm = self.ca_comm
        pymscclpp_comm = self.pymscclpp_comm
        assert ca_comm is not None or pymscclpp_comm is not None
        if ca_comm is not None and not ca_comm.disabled:
            out = ca_comm.custom_all_reduce(input_)
        else:
            assert not pymscclpp_comm.disabled
            out = pymscclpp_comm.all_reduce(input_)
        assert out is not None
        return out

@@ -958,6 +991,7 @@ def init_world_group(
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,

@@ -973,14 +1007,18 @@ def init_model_parallel_group(
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
    use_mscclpp_allreduce: Optional[bool] = None,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    if use_mscclpp_allreduce is None:
        use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=not is_npu(),
        use_pymscclpp=use_mscclpp_allreduce,
        use_custom_allreduce=use_custom_allreduce,
        use_hpu_communicator=True,
        use_xpu_communicator=True,

@@ -1037,6 +1075,7 @@ def graph_capture():
logger = logging.getLogger(__name__)

_ENABLE_CUSTOM_ALL_REDUCE = True
_ENABLE_MSCCLPP_ALL_REDUCE = False


def set_custom_all_reduce(enable: bool):

@@ -1044,6 +1083,11 @@ def set_custom_all_reduce(enable: bool):
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def set_mscclpp_all_reduce(enable: bool):
    global _ENABLE_MSCCLPP_ALL_REDUCE
    _ENABLE_MSCCLPP_ALL_REDUCE = enable


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
```
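To make the selection table in the hunks above concrete, the resulting dispatch can be summarized by the following hedged sketch (a simplification, not the literal implementation; the real out-of-place path goes through `torch.ops.sglang.outplace_all_reduce`):

```python
# Simplified summary of GroupCoordinator's all-reduce path selection after this change.
def _choose_all_reduce_path(coordinator, input_) -> str:
    ca = coordinator.ca_comm
    ms = coordinator.pymscclpp_comm
    if ca is not None and not ca.disabled and ca.should_custom_ar(input_):
        return "custom all-reduce kernel (out-of-place)"
    if ms is not None and not ms.disabled and ms.should_mscclpp_allreduce(input_):
        return "mscclpp one-shot all-reduce (out-of-place, graph capture only)"
    return "pynccl (inside CUDA graph) or torch.distributed (eager)"
```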
python/sglang/srt/layers/dp_attention.py (+6, -5)

```python
@@ -98,11 +98,12 @@ def initialize_dp_attention(
        ],
        local_rank,
        torch.distributed.get_backend(tp_group.device_group),
        use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="attention_tp",
    )
```
python/sglang/srt/model_executor/model_runner.py (+2)

```python
@@ -35,6 +35,7 @@ from sglang.srt.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

@@ -460,6 +461,7 @@ class ModelRunner:
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
```
python/sglang/srt/server_args.py (+6)

```python
@@ -165,6 +165,7 @@ class ServerArgs:
    enable_tokenizer_batch_encode: bool = False
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    enable_mscclpp: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False

@@ -1168,6 +1169,11 @@ class ServerArgs:
            action="store_true",
            help="Disable the custom all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--enable-mscclpp",
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",
```
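End to end, the new flag flows from the CLI into the distributed global toggled by the `model_runner.py` hunk above. A small hedged illustration (standalone argparse rather than the full `ServerArgs` machinery):

```python
# Illustration only: how --enable-mscclpp reaches set_mscclpp_all_reduce.
import argparse

from sglang.srt.distributed import set_mscclpp_all_reduce

parser = argparse.ArgumentParser()
parser.add_argument("--enable-mscclpp", action="store_true")
args = parser.parse_args(["--enable-mscclpp"])

# Same call the model runner makes with server_args.enable_mscclpp.
set_mscclpp_all_reduce(args.enable_mscclpp)
```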
sgl-kernel/CMakeLists.txt (+31, -1)

```cmake
@@ -73,6 +73,14 @@ FetchContent_Declare(
  GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)

# mscclpp
FetchContent_Declare(
  repo-mscclpp
  GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
  GIT_TAG 51eca89d20f0cfb3764ccd764338d7b22cd486a6
  GIT_SHALLOW OFF
)
FetchContent_Populate(repo-mscclpp)

# ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON)

@@ -99,6 +107,7 @@ include_directories(
  ${repo-cutlass_SOURCE_DIR}/tools/util/include
  ${repo-flashinfer_SOURCE_DIR}/include
  ${repo-flashinfer_SOURCE_DIR}/csrc
  ${repo-mscclpp_SOURCE_DIR}/include
)

set(SGL_KERNEL_CUDA_FLAGS

@@ -196,6 +205,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

set(SOURCES
  "csrc/allreduce/mscclpp_allreduce.cu"
  "csrc/allreduce/custom_all_reduce.cu"
  "csrc/attention/cascade.cu"
  "csrc/attention/merge_attn_states.cu"

@@ -250,7 +260,27 @@ target_include_directories(common_ops PRIVATE
  ${repo-cutlass_SOURCE_DIR}/examples/common
  ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)

find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
  COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
  OUTPUT_VARIABLE TORCH_CXX11_ABI
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(TORCH_CXX11_ABI STREQUAL "0")
  message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
else()
  message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
endif()

set(MSCCLPP_USE_CUDA ON)
set(MSCCLPP_BYPASS_GPU_CHECK ON)
set(MSCCLPP_BUILD_TESTS OFF)
add_subdirectory(${repo-mscclpp_SOURCE_DIR})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)

target_compile_definitions(common_ops PRIVATE
  FLASHATTENTION_DISABLE_BACKWARD
```
sgl-kernel/Makefile (+2, -2)

```makefile
@@ -19,14 +19,14 @@ submodule: ## Initialize and update git submodules
	@git submodule update --init --recursive

ln: submodule ## Create compilation database
	@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5

install: submodule ## Install package in development mode
	@pip install -e . --no-build-isolation

build: install-deps submodule ## Build and install wheel package
	@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps

clean: ## Remove build artifacts
	@rm -rf build dist *.egg-info
```
sgl-kernel/build.sh (+3)

```bash
@@ -50,6 +50,9 @@ docker run --rm \
   which cmake
   cmake --version

   yum install numactl-devel -y && \
   yum install libibverbs -y && \
   ln -sv /usr/lib64/libibverbs.so.1 /usr/lib64/libibverbs.so && \
   ${PYTHON_ROOT_PATH}/bin/${TORCH_INSTALL} && \
   ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv scikit-build-core && \
   export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
```
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu (new file, +140)

```cuda
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <torch/library.h>

#include "mscclpp_allreduce.cuh"

enum MscclContextSelection {
  MSCCL1NODELL = 1,
  MSCCL2NODELL = 2,
};

class MscclContext {
 public:
  MscclContextSelection selection_;
  std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
  std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;

  MscclContext(MscclContextSelection selection) : selection_(selection) {}

  template <typename T>
  void allreduce(
      cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
    if (selection_ == MSCCL1NODELL) {
      msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
    } else if (selection_ == MSCCL2NODELL) {
      msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
    }
  }
};

using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
  auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
  auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
  std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
  return tensor;
}

// Function to convert vector of int32_t back to array of uint8_t
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
  mscclpp::UniqueId unique_id;
  std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
  return unique_id;
}

torch::Tensor mscclpp_generate_unique_id() {
  mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
  return _unique_id2tensor(unique_id);
}

fptr_t mscclpp_init_context(
    const torch::Tensor& unique_id,
    const int64_t rank,
    const int64_t world_size,
    torch::Tensor& scratch,
    torch::Tensor& put_buffer,
    const int64_t nranks_per_node,
    const std::vector<int64_t>& rank_to_node,
    const std::vector<int64_t>& rank_to_ib,
    const int64_t context_selection) {
  MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
  mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
  if (context_selection == MSCCL1NODELL) {
    void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
    const size_t scratch_bytes = scratch.numel() * scratch.element_size();
    context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
        uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
  } else if (context_selection == MSCCL2NODELL) {
    void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
    const size_t scratch_bytes = scratch.numel() * scratch.element_size();
    void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
    const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
    context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
        uid,
        rank,
        world_size,
        scratch_ptr,
        scratch_bytes,
        put_buffer_ptr,
        put_buffer_bytes,
        nranks_per_node,
        rank_to_node,
        rank_to_ib);
  } else {
    throw std::runtime_error("invalid context selection");
  }
  return (fptr_t)context_ptr;
}

bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}

void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
  MscclContext* context = reinterpret_cast<MscclContext*>(_context);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
  TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      context->allreduce<float>(
          stream,
          reinterpret_cast<float*>(inp.data_ptr()),
          reinterpret_cast<float*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
    case at::ScalarType::Half: {
      context->allreduce<half>(
          stream,
          reinterpret_cast<half*>(inp.data_ptr()),
          reinterpret_cast<half*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      context->allreduce<__nv_bfloat16>(
          stream,
          reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
#endif
    default:
      throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
  }
}
```
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh (new file, +779; diff collapsed in the commit view, not shown here)
sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu (new file, +153)

```cuda
/*
 * this file is used to test mscclpp_allreduce.cu using mpirun
 * this file is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu
 usage:
  cd PATH-TO-THIS-FILE
  export MPI_HOME=/usr/local/mpi
  # export MPI_HOME=/opt/hpcx/ompi/
  export MSCCLPP_HOME=/workspace/test/mscclpp
  nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \
    -o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \
    -I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \
    -lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi
  /opt/hpcx/ompi/bin/mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \
    --map-by ppr:8:node \
    --mca btl_openib_warn_no_device_params_found 0 \
    --mca btl_tcp_if_include bond0 \
    --allow-run-as-root -np 8 \
    -x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \
    -x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce
 */
#include <mpi.h>
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#ifndef CHECK_CUDA_SUCCESS
#define CHECK_CUDA_SUCCESS(cmd)                                                              \
  do {                                                                                       \
    cudaError_t e = cmd;                                                                     \
    if (e != cudaSuccess) {                                                                  \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e));  \
      exit(EXIT_FAILURE);                                                                    \
    }                                                                                        \
  } while (0)
#endif

#include <cstdint>

#include "mscclpp_allreduce.cuh"

template <typename T>
bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) {
  return fabs(a - b) <= (atol + rtol * fabs(b));
}

int main(int argc, char* argv[]) {
  // init mpi
  MPI_Init(&argc, &argv);
  printf("MPI Initialized.\n");
  int nranks, rank;
  // get work size and rank id
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  cudaSetDevice(rank);
  printf("nranks: %d, rank: %d\n", nranks, rank);

  // init host and device buffers
  using T = float;
  using ReduceT = float;
  const size_t num_elems = 2 * 1024 * 1024;
  std::vector<T> host_buf(num_elems);
  for (uint32_t i = 0; i < num_elems; ++i) {
    host_buf[i] = T(i + rank);
  }
  thrust::device_vector<T> device_buf(host_buf);
  const size_t buf_size_in_bytes = num_elems * sizeof(T);

  std::vector<T> host_result_buf(num_elems);
  thrust::device_vector<T> device_result_buf(host_result_buf);

  std::vector<T> host_scratch_buf(num_elems * 8);
  for (uint32_t i = 0; i < num_elems; ++i) {
    host_scratch_buf[i] = 1;
  }
  thrust::device_vector<T> device_scratch_buf(host_scratch_buf);

  std::vector<T> host_put_buf(num_elems);
  thrust::device_vector<T> device_put_buf(host_put_buf);

  mscclpp::UniqueId unique_id;
  if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId();
  MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD);

  std::vector<int64_t> rank_to_node(nranks);
  std::vector<int64_t> rank_to_ib(nranks);
  for (int i = 0; i < nranks; i++) {
    rank_to_node[i] = i / 8;
    rank_to_ib[i] = i % 8;
  }

  cudaStream_t s;
  CHECK_CUDA_SUCCESS(cudaStreamCreate(&s));
  CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
  if (nranks == 8) {
    auto context = std::make_shared<sglang::Msccl1NodeLLcontext>(
        unique_id,
        rank,
        nranks,
        thrust::raw_pointer_cast(device_scratch_buf.data()),
        buf_size_in_bytes * 8,
        rank_to_node,
        rank_to_ib);
    printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank);
    MPI_Barrier(MPI_COMM_WORLD);
    context->allreduce<T>(
        s,
        thrust::raw_pointer_cast(device_buf.data()),
        thrust::raw_pointer_cast(device_result_buf.data()),
        device_buf.size());
  } else if (nranks == 16) {
    // TODO: this branch is untested since there is something wrong with mpirun on my test machine
    auto context = std::make_shared<sglang::Msccl2NodeLLcontext>(
        unique_id,
        rank,
        nranks,
        thrust::raw_pointer_cast(device_scratch_buf.data()),
        buf_size_in_bytes * 8,
        thrust::raw_pointer_cast(device_put_buf.data()),
        buf_size_in_bytes,
        rank_to_node,
        rank_to_ib);
    printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank);
    MPI_Barrier(MPI_COMM_WORLD);
    context->allreduce<T>(
        s,
        thrust::raw_pointer_cast(device_buf.data()),
        thrust::raw_pointer_cast(device_result_buf.data()),
        device_buf.size());
  }

  // check result correctness
  thrust::host_vector<T> host_buf_result = device_result_buf;
  size_t num_results_error_atol_1e_3_rtol_1e_3 = 0;
  bool nan_detected = false;
  for (uint32_t i = 0; i < num_elems; ++i) {
    T expected = T(i * nranks + (nranks - 1) * nranks / 2);
    if (std::isnan(float(host_buf_result[i]))) {
      nan_detected = true;
    }
    if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) {
      num_results_error_atol_1e_3_rtol_1e_3++;
    }
  }
  float result_accuracy = 1. - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems);
  printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy);

  CHECK_CUDA_SUCCESS(cudaStreamDestroy(s));
  MPI_Finalize();
  return 0;
}
```
sgl-kernel/csrc/common_extension.cc (+9)

```cpp
@@ -38,6 +38,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
      "int reg_buffer_sz_bytes) -> ()");
  m.impl("all_reduce", torch::kCUDA, &all_reduce);
  m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id);
  m.def(
      "mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, "
      "int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int");
  m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context);
  m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
  m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);

  /*
   * From csrc/attention
   */
```
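Once these schemas are registered, the ops are reachable from Python through `torch.ops.sgl_kernel`, which is exactly what `sgl-kernel/python/sgl_kernel/allreduce.py` below wraps. A hedged sketch (requires a CUDA build of sgl-kernel; the commented calls assume fully initialized arguments):

```python
# Sketch: the registered schemas map to these Python entry points.
import torch
import sgl_kernel  # noqa: F401  (loads the extension and registers the ops)

uid = torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()  # CPU byte tensor
# ctx = torch.ops.sgl_kernel.mscclpp_init_context.default(
#     uid, rank, world_size, scratch, put_buffer,
#     nranks_per_node, rank_to_node, rank_to_ib, context_selection)
# torch.ops.sgl_kernel.mscclpp_allreduce.default(ctx, inp, out, nthreads, nblocks)
```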
sgl-kernel/include/sgl_kernel_ops.h (+12)

```cpp
@@ -74,6 +74,18 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);

torch::Tensor mscclpp_generate_unique_id();
fptr_t mscclpp_init_context(
    const torch::Tensor& unique_id,
    const int64_t rank,
    const int64_t world_size,
    torch::Tensor& scratch,
    torch::Tensor& put_buffer,
    const int64_t nranks_per_node,
    const std::vector<int64_t>& rank_to_node,
    const std::vector<int64_t>& rank_to_ib,
    const int64_t context_selection);
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks);
#endif

/*
```
sgl-kernel/python/sgl_kernel/allreduce.py (+54)

```python
@@ -49,6 +49,27 @@ if torch.version.hip is not None:
    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

    def mscclpp_generate_unique_id() -> bytes:
        raise NotImplementedError()

    def mscclpp_init_context(
        unique_id: bytes,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        raise NotImplementedError()

    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        raise NotImplementedError()

else:

    def init_custom_ar(

@@ -85,3 +106,36 @@ else:
    def meta_size() -> int:
        return torch.ops.sgl_kernel.meta_size.default()

    def mscclpp_generate_unique_id() -> torch.Tensor:
        return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()

    def mscclpp_init_context(
        unique_id: torch.Tensor,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        return torch.ops.sgl_kernel.mscclpp_init_context.default(
            unique_id,
            rank,
            world_size,
            scratch,
            put_buffer,
            nranks_per_node,
            rank_to_node,
            rank_to_ib,
            context_selection,
        )

    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        torch.ops.sgl_kernel.mscclpp_allreduce.default(
            context, inp, out, nthreads, nblocks
        )
```
sgl-kernel/tests/test_mscclpp.py (new file, +146)

```python
import multiprocessing as mp
import os
import socket
import unittest
from enum import IntEnum
from typing import Any

import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist


class MscclContextSelection(IntEnum):
    MSCCL1SHOT1NODELL = 1
    MSCCL1SHOT2NODELL = 2


def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")

    if rank == 0:
        unique_id = [custom_ops.mscclpp_generate_unique_id()]
    else:
        unique_id = [None]
    dist.broadcast_object_list(
        unique_id, src=0, device=torch.device("cpu"), group=cpu_group
    )
    unique_id = unique_id[0]
    rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
    for r in range(world_size):
        rank_to_node[r] = r // 8
        rank_to_ib[r] = rank % 8
    MAX_BYTES = 2**20
    scratch = torch.empty(
        MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    put_buffer = torch.empty(
        MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    print(f"[{rank}] start mscclpp_context init")
    nranks_per_node = torch.cuda.device_count()
    selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
    mscclpp_context = custom_ops.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        selection,
    )

    try:
        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                if sz * dtype.itemsize > MAX_BYTES:
                    continue
                if rank == 0:
                    print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
                for _ in range(test_loop):
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.mscclpp_allreduce(
                        mscclpp_context, inp1, out1, nthreads=512, nblocks=21
                    )
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        dist.barrier(group=group)
        dist.destroy_process_group(group=group)


def get_open_port() -> int:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"


class TestMSCCLAllReduce(unittest.TestCase):
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
    ]
    world_sizes = [8]

    def test_correctness(self):
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, found {available_gpus} and now ray is not supported here"
                )
                continue
            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")


if __name__ == "__main__":
    unittest.main()
```
test/srt/test_mscclpp.py (new file, +205)

```python
"""For now, MSCCLPP all-reduce is only supported in the TP8 and TP16 cases.

if [[ $RANK -eq 0 ]]; then
    ray start --block --head --port=6379 &
    python3 test_mscclpp.py;
else
    ray start --block --address=${MASTER_ADDR}:6379;
fi
"""

import itertools
import os
import random
import socket
import unittest
from contextlib import contextmanager, nullcontext
from typing import Any, List, Optional, Union

import ray
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
    CustomAllreduce,
)
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
)
from sglang.srt.distributed.utils import StatelessProcessGroup
from sglang.test.test_utils import CustomTestCase


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    master_addr: str,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=True)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(
            test_target.remote(cls, world_size, master_addr, rank, distributed_init_port)
        )
    ray.get(refs)

    ray.shutdown()


class TestMSCCLAllReduce(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        # 1KB to 1MB
        cls.test_sizes = [512, 4096, 32768, 262144, 524288]
        cls.world_sizes = [8]
        TEST_TP16 = int(os.getenv("SGL_MSCCLPP_TEST_TP16", "0"))
        if TEST_TP16:
            cls.world_sizes = [16]
        cls.test_loop = 10

    def test_graph_allreduce(self):
        TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
        for world_size in self.world_sizes:
            if world_size not in [8, 16]:
                continue
            multi_process_parallel(
                world_size, TEST_MASTER_ADDR, self, self.graph_allreduce
            )

    def test_eager_allreduce(self):
        TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
        for world_size in self.world_sizes:
            if world_size not in [8, 16]:
                continue
            multi_process_parallel(
                world_size, TEST_MASTER_ADDR, self, self.eager_allreduce
            )

    @ray.remote(num_gpus=1, max_calls=1)
    def graph_allreduce(self, world_size, master_addr, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
        set_mscclpp_all_reduce(True)
        set_custom_all_reduce(False)
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank % torch.cuda.device_count(),
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        # A small all_reduce for warmup.
        # this is needed because device communicators might be created lazily
        # (e.g. NCCL). This will ensure that the communicator is initialized
        # before any communication happens, so that this group can be used for
        # graph capture immediately.
        data = torch.zeros(1)
        data = data.to(device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    with graph_capture() as graph_capture_context:
                        # use integers so result matches NCCL exactly
                        inp1 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        inp2 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        torch.cuda.synchronize()
                        graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(
                            graph, stream=graph_capture_context.stream
                        ):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                    graph.replay()
                    torch.testing.assert_close(out1, inp1)
                    torch.testing.assert_close(out2, inp2)

    @ray.remote(num_gpus=1, max_calls=1)
    def eager_allreduce(self, world_size, master_addr, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
        set_mscclpp_all_reduce(True)
        set_custom_all_reduce(False)
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)


if __name__ == "__main__":
    unittest.main()
```