change / sglang · Commit 8e3797be (Unverified)
Authored Jun 05, 2025 by zyksir · Committed by GitHub on Jun 04, 2025
support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)
Parent: 4474eaf5
Showing 20 changed files with 2177 additions and 12 deletions (+2177, -12)
Changed files:
- benchmark/kernels/all_reduce/benchmark_mscclpp.py (+224, -0)
- benchmark/lora/launch_server.py (+7, -0)
- docs/backend/server_arguments.md (+1, -0)
- python/sglang/srt/_custom_ops.py (+34, -0)
- python/sglang/srt/distributed/device_communicators/pymscclpp.py (+315, -0)
- python/sglang/srt/distributed/parallel_state.py (+48, -4)
- python/sglang/srt/layers/dp_attention.py (+6, -5)
- python/sglang/srt/model_executor/model_runner.py (+2, -0)
- python/sglang/srt/server_args.py (+6, -0)
- sgl-kernel/CMakeLists.txt (+31, -1)
- sgl-kernel/Makefile (+2, -2)
- sgl-kernel/build.sh (+3, -0)
- sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu (+140, -0)
- sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh (+779, -0)
- sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu (+153, -0)
- sgl-kernel/csrc/common_extension.cc (+9, -0)
- sgl-kernel/include/sgl_kernel_ops.h (+12, -0)
- sgl-kernel/python/sgl_kernel/allreduce.py (+54, -0)
- sgl-kernel/tests/test_mscclpp.py (+146, -0)
- test/srt/test_mscclpp.py (+205, -0)
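End to end, the commit threads a new `--enable-mscclpp` server flag through `set_mscclpp_all_reduce()` into `GroupCoordinator`, which then owns a `PyMscclppCommunicator` that serves small-message all-reduce under CUDA graph capture. A minimal sketch of turning the path on programmatically, using the same calls as the benchmark below (the single-node, 8-GPU setup is an assumption):

```python
# Sketch: enable the mscclpp all-reduce path before building the TP group,
# mirroring benchmark/kernels/all_reduce/benchmark_mscclpp.py below.
import torch
import torch.distributed as dist
from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    initialize_model_parallel,
    set_mscclpp_all_reduce,
)

dist.init_process_group(backend="nccl")  # e.g. launched under torchrun
rank = dist.get_rank()
torch.cuda.set_device(rank % 8)

set_mscclpp_all_reduce(True)  # what --enable-mscclpp toggles via ModelRunner
init_distributed_environment(
    world_size=dist.get_world_size(), rank=rank, local_rank=rank % 8
)
initialize_model_parallel(tensor_model_parallel_size=dist.get_world_size())
comm = get_tensor_model_parallel_group().pymscclpp_comm  # PyMscclppCommunicator or None
```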
benchmark/kernels/all_reduce/benchmark_mscclpp.py (new file, mode 100644)

```python
"""For Now, MSCCL is only supported on TP16 and TP8 case
export WORLD_SIZE=1
export RANK=0
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345
torchrun --nproc_per_node gpu \
    --nnodes $WORLD_SIZE \
    --node_rank $RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py
"""

import os
from contextlib import nullcontext
from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
    set_mscclpp_all_reduce,
)


def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
    dist.all_reduce(torch_input, group=group)
    return torch_input


def msccl_allreduce(
    msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator
) -> torch.Tensor:
    return msccl_comm.all_reduce(msccl_input)


def pynccl_allreduce(
    msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
) -> torch.Tensor:
    pynccl_comm.all_reduce(msccl_input)
    return msccl_input


def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
    graph_input = inp_randn.clone()
    with graph_capture() as graph_capture_context:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
            for _ in range(graph_loop):
                graph_out = func(graph_input)
    graph.replay()
    func_output = graph_out.clone()

    for _ in range(warmup_loop):
        graph.replay()
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    latencies: List[float] = []
    for _ in range(test_loop):
        torch.cuda.synchronize()
        dist.barrier()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
    graph.reset()
    return func_output, func_cost_us


def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
    eager_input = inp_randn.clone()
    eager_output = func(eager_input)
    func_output = eager_output.clone()
    for _ in range(warmup_loop):
        func(eager_input)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(test_loop):
        func(eager_input)
    end_event.record()
    torch.cuda.synchronize()
    func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
    return func_output, func_cost_us


def get_torch_prof_ctx(do_prof: bool):
    ctx = (
        torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=True,
            with_stack=True,
        )
        if do_prof
        else nullcontext()
    )
    return ctx


def human_readable_size(size, decimal_places=1):
    for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
        if size < 1024.0 or unit == "PiB":
            break
        size /= 1024.0
    return f"{size:.{decimal_places}f} {unit}"


try:
    from tabulate import tabulate
except ImportError:
    print("tabulate not installed, skipping table printing")
    tabulate = None


def print_markdown_table(data):
    if tabulate is not None:
        print(tabulate(data, headers="keys", tablefmt="github"))
        return
    headers = data[0].keys()
    header_row = "| " + " | ".join(headers) + " |"
    separator = "| " + " | ".join(["---"] * len(headers)) + " |"
    rows = []
    for item in data:
        row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
        rows.append(row)
    markdown_table = "\n".join([header_row, separator] + rows)
    print(markdown_table)


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    world, world_size = dist.group.WORLD, dist.get_world_size()
    rank = dist.get_rank()
    torch.cuda.set_device(rank % 8)
    device = torch.cuda.current_device()
    set_mscclpp_all_reduce(True)
    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        local_rank=rank % 8,
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)
    group = get_tensor_model_parallel_group().device_group
    cpu_group = get_tensor_model_parallel_group().cpu_group
    pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
    pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm
    dist.barrier()
    profile = False
    dtype = torch.bfloat16
    ctx = get_torch_prof_ctx(profile)
    result = []
    with ctx:
        for i in range(10, 20):
            sz = 2**i
            if sz * dtype.itemsize > 2**20:
                break
            inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
            memory = torch.empty_like(inp_randn)
            memory_out = torch.empty_like(memory)
            torch_eager_output, torch_eager_time = _bench_eager_time(
                lambda inp: torch_allreduce(inp, group), inp_randn
            )
            msccl_eager_output, msccl_eager_time = _bench_eager_time(
                lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
            )
            msccl_graph_output, msccl_graph_time = _bench_graph_time(
                lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
            )
            # since pynccl is inplace op, this return result is not correct if graph loop > 1
            _, pynccl_graph_time = _bench_graph_time(
                lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
            )
            torch.testing.assert_close(torch_eager_output, msccl_graph_output)
            torch.testing.assert_close(torch_eager_output, msccl_eager_output)
            result.append(
                {
                    "msg_size": human_readable_size(inp_randn.nbytes),
                    "torch eager time": torch_eager_time,
                    "msccl eager time": msccl_eager_time,
                    "msccl graph time": msccl_graph_time,
                    "pynccl graph time": pynccl_graph_time,
                }
            )
            if rank == 0:
                print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
    if rank == 0:
        print_markdown_table(result)
    if profile:
        prof_dir = f"prof/msccl"
        os.makedirs(prof_dir, exist_ok=True)
        ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
```
benchmark/lora/launch_server.py

```diff
@@ -26,6 +26,8 @@ def launch_server(args):
     cmd += f"--tp-size {args.tp_size} "
     if args.disable_custom_all_reduce:
         cmd += "--disable-custom-all-reduce "
+    if args.enable_mscclpp:
+        cmd += "--enable-mscclpp "
     print(cmd)
     os.system(cmd)
@@ -63,6 +65,11 @@ if __name__ == "__main__":
         action="store_true",
         help="Disable custom all reduce when device does not support p2p communication",
     )
+    parser.add_argument(
+        "--enable-mscclpp",
+        action="store_true",
+        help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
+    )
     args = parser.parse_args()
     launch_server(args)
```
docs/backend/server_arguments.md

```diff
@@ -201,6 +201,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `disable_cuda_graph_padding` | Disable CUDA Graph when padding is needed; otherwise, still use CUDA Graph. | `False` |
 | `disable_outlines_disk_cache` | Disable disk cache for outlines grammar backend. | `False` |
 | `disable_custom_all_reduce` | Disable usage of custom all-reduce kernel. | `False` |
+| `enable_mscclpp` | Enable usage of mscclpp kernel for small message all-reduce. | `False` |
 | `disable_overlap_schedule` | Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). | `False` |
 | `enable_nan_detection` | Enable warning if the logits contain `NaN`. | `False` |
 | `enable_p2p_check` | Turns off the default of always allowing P2P checks when accessing GPU. | `False` |
```
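For reference, the new argument is consumed the same way `benchmark/lora/launch_server.py` does above. A hedged sketch of passing it through a server launch from Python; the model path is a placeholder and the TP size an assumption, not values taken from this commit:

```python
# Sketch only: spawn an sglang server with the new flag.
# "--enable-mscclpp" and "--tp-size" come from this commit's server_args.py /
# launch_server.py changes; the model path is a hypothetical placeholder.
import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    "--tp-size", "8",
    "--enable-mscclpp",
]
subprocess.run(cmd, check=True)
```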
python/sglang/srt/_custom_ops.py

```diff
@@ -113,3 +113,37 @@ else:

     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
         return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
+
+    def mscclpp_generate_unique_id() -> bytes:
+        return sgl_kernel.allreduce.mscclpp_generate_unique_id()
+
+    def mscclpp_init_context(
+        unique_id: bytes,
+        rank: int,
+        world_size: int,
+        scratch: torch.Tensor,
+        put_buffer: torch.Tensor,
+        nranks_per_node: int,
+        rank_to_node: List[int],
+        rank_to_ib: List[int],
+        context_selection: int,
+    ) -> int:
+        return sgl_kernel.allreduce.mscclpp_init_context(
+            unique_id,
+            rank,
+            world_size,
+            scratch,
+            put_buffer,
+            nranks_per_node,
+            rank_to_node,
+            rank_to_ib,
+            context_selection,
+        )
+
+    def mscclpp_allreduce(
+        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
+    ) -> None:
+        return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
```
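These wrappers expose the call sequence that `PyMscclppCommunicator` (added below) follows: rank 0 generates a unique id and broadcasts it, every rank builds a context over preallocated scratch/put buffers, and all-reduce then runs out-of-place with a chosen `(nthreads, nblocks)`. A minimal sketch under assumed conditions (torch.distributed already initialized with 8 consecutive ranks on one node); the buffer sizes mirror the communicator's defaults:

```python
# Minimal sketch of the raw op flow wrapped by PyMscclppCommunicator below.
# Assumes torch.distributed is initialized with 8 ranks on a single node.
import torch
import torch.distributed as dist
from sglang.srt import _custom_ops as ops

rank, world_size = dist.get_rank(), dist.get_world_size()
max_bytes = 2**20  # 1 MiB, the default SGLANG_MSCCLPP_MAX_BYTES

uid = [ops.mscclpp_generate_unique_id()] if rank == 0 else [None]
dist.broadcast_object_list(uid, src=0)

scratch = torch.empty(max_bytes * 8, dtype=torch.uint8, device="cuda")
put_buffer = torch.empty(max_bytes, dtype=torch.uint8, device="cuda")  # max_bytes*8 // 8 ranks
context = ops.mscclpp_init_context(
    uid[0], rank, world_size, scratch, put_buffer,
    nranks_per_node=8,
    rank_to_node=[r // 8 for r in range(world_size)],
    rank_to_ib=[r % 8 for r in range(world_size)],
    context_selection=1,  # MSCCL1SHOT1NODELL in pymscclpp.py
)

inp = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
out = torch.empty_like(inp)
ops.mscclpp_allreduce(context, inp, out, nthreads=512, nblocks=21)
```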
python/sglang/srt/distributed/device_communicators/pymscclpp.py (new file, mode 100644)

```python
import bisect
import logging
import math
import os
from contextlib import contextmanager
from enum import IntEnum
from typing import Any, Callable, List, Optional, TypeVar, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

from sglang.srt import _custom_ops as ops
from sglang.srt.utils import is_cuda, is_hip

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

mscclpp_is_available = False
if _is_hip:
    # TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
    mscclpp_is_available = False
if _is_cuda:
    try:
        import sgl_kernel

        mscclpp_is_available = True
    except:
        mscclpp_is_available = False


class MscclContextSelection(IntEnum):
    MSCCL1SHOT1NODELL = 1
    MSCCL1SHOT2NODELL = 2


def mscclpp_is_weak_contiguous(inp: torch.Tensor):
    return inp.is_contiguous() or (
        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
        == inp.numel() * inp.element_size()
    )


def mscclpp_convert_to_bytes(size_str):
    """
    Converts a human-readable size string (e.g., "1MB", "2.5kb", "3 GB")
    into the equivalent number of bytes using binary units.

    Args:
        size_str (str): A string representing size with unit (KB, MB, GB).

    Returns:
        int: Number of bytes.
    """
    size_str = size_str.strip().lower()
    if not size_str:
        raise ValueError("Empty input string")

    # Extract numeric part and unit
    for i in range(len(size_str)):
        if not size_str[i].isdigit() and size_str[i] != ".":
            break
    num_str = size_str[:i]
    unit = size_str[i:].strip()

    try:
        num = float(num_str)
    except ValueError:
        raise ValueError(f"Invalid numeric value in '{size_str}'")

    # Conversion factors
    if unit == "b":
        return int(num)
    elif unit == "kb":
        return int(num * 1024)
    elif unit == "mb":
        return int(num * 1024 * 1024)
    elif unit == "gb":
        return int(num * 1024 * 1024 * 1024)
    else:
        raise ValueError(f"Unsupported unit: {unit}, support B, KB, MB, GB only")


def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
    # warmup
    for _ in range(warmup_niter):
        func()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    dist.barrier()
    start_event.record()
    for _ in range(test_niter):
        func()
    end_event.record()
    end_event.synchronize()
    func_cost_us = start_event.elapsed_time(end_event) / test_niter * 1000
    return func_cost_us


class PyMscclppCommunicator:
    _SUPPORTED_WORLD_SIZES = [8, 16]
    _MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
    _SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]

    # max_bytes: max supported mscclpp allreduce size
    # in A100 mscclpp is faster than nccl only under condition of msg size smaller than 1MB
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_bytes=_MAX_BYTES,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bind to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True

        if not mscclpp_is_available:
            # disable because of missing mscclpp library
            # e.g. in a non-cuda environment
            return

        self.group = group

        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "CustomAllreduce should be attached to a non-NCCL group."

        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize mscclpp for single GPU case.
            return

        if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "PyMscclpp is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_mscclpp=True explicitly.",
                world_size,
                str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),
            )
            return

        self.ranks = torch.distributed.get_process_group_ranks(group)
        self.nranks_per_node = torch.cuda.device_count()
        # for now mscclpp with stride in the communicator is not tested
        if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):
            logger.warning(
                "PyMscclpp is disabled due to an unsupported group %s."
                "Please ensure all ranks in the group are consecutive."
                "To silence this warning, specify disable_mscclpp=True explicitly.",
                str(self.ranks),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        self.max_bytes = max_bytes
        self.rank = rank
        self.world_size = world_size

        if dist.get_rank(group) == 0:
            unique_id = [ops.mscclpp_generate_unique_id()]
        else:
            unique_id = [None]
        dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)
        self.unique_id = unique_id[0]
        self.rank_to_node, self.rank_to_ib = list(range(world_size)), list(
            range(world_size)
        )
        for r in range(world_size):
            self.rank_to_node[r] = r // 8
            self.rank_to_ib[r] = self.rank % 8

        self._context = None
        self.context_selection = None
        self.msg_size_for_finetune = [
            2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)
        ]
        self.msg_size2best_config = {}
        if world_size == 8:
            self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL
        elif world_size == 16:
            self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL
        if not _is_hip:
            self.scratch = torch.empty(
                self.max_bytes * 8,
                dtype=torch.uint8,
                device=self.device,
            )
            self.put_buffer = torch.empty(
                self.max_bytes * 8 // self.nranks_per_node,
                dtype=torch.uint8,
                device=self.device,
            )
            self._context = ops.mscclpp_init_context(
                self.unique_id,
                self.rank,
                self.world_size,
                self.scratch,
                self.put_buffer,
                self.nranks_per_node,
                self.rank_to_node,
                self.rank_to_ib,
                int(self.context_selection),
            )
        else:
            raise NotImplementedError("HIP Mscclpp is not supported yet.")

        self.msg_size2best_config = {}
        self.pre_tune_config()
        if dist.get_rank(group) == 0:
            msg_size2best_config = [self.msg_size2best_config]
        else:
            msg_size2best_config = [None]
        dist.broadcast_object_list(
            msg_size2best_config, src=self.ranks[0], group=self.group
        )
        self.msg_size2best_config = msg_size2best_config[0]

        # PyMscclpp is enabled only in cuda graph
        self.disabled = True

    def pre_tune_config(self, dtype=torch.bfloat16) -> bool:
        logger.debug(f"start to pre-tune configs for rank {self.rank}")
        nthreads_to_try = [256, 512, 1024]
        nblocks_to_try = [21, 42, 84]
        inp_randn = torch.ones(
            self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device="cuda"
        )
        oup_randn = torch.empty_like(inp_randn)
        for msg_size in self.msg_size_for_finetune:
            mock_inp, mock_outp = (
                inp_randn[: msg_size // dtype.itemsize],
                oup_randn[: msg_size // dtype.itemsize],
            )
            best_config, best_time = None, None
            for nthreads in nthreads_to_try:
                for nblocks in nblocks_to_try:
                    cur_cost = mscclpp_bench_time(
                        lambda: ops.mscclpp_allreduce(
                            self._context, mock_inp, mock_outp, nthreads, nblocks
                        )
                    )
                    if best_time is None or cur_cost < best_time:
                        best_config = (nthreads, nblocks)
                        best_time = cur_cost
            self.msg_size2best_config[msg_size] = best_config
            if self.rank == 0:
                logger.debug(
                    f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us"
                )

    def should_mscclpp_allreduce(
        self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM
    ) -> bool:
        if self.disabled or self._context is None:
            return False
        if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE:
            return False
        if not mscclpp_is_weak_contiguous(inp):
            return False
        # only support sum op
        if op != ReduceOp.SUM:
            return False
        if inp.numel() * inp.element_size() > self.max_bytes:
            return False
        return True

    def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                self.graph_input_set.add((tensor.dtype, tensor.numel()))
        msg_size = tensor.numel() * tensor.itemsize
        index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)
        msg_size_finetune = self.msg_size_for_finetune[index]
        nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]
        result = torch.empty_like(tensor)
        ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)
        return result

    @contextmanager
    def change_state(
        self,
        enable: Optional[bool] = None,
    ):
        if enable is None:
            # guess a default value when not specified
            enable = self.available
        old_disable = self.disabled
        self.disabled = not enable
        yield
        self.disabled = old_disable
```
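Because the constructor leaves `disabled=True`, the communicator only becomes active when a caller flips it with `change_state()` (which `GroupCoordinator.graph_capture()` does in the `parallel_state.py` diff below), and each tensor still has to pass `should_mscclpp_allreduce()`. A small usage sketch built only on the methods above; the helper name is hypothetical:

```python
# Sketch of a call site: check eligibility per tensor while the communicator
# is enabled, otherwise fall back to another all-reduce implementation.
def maybe_mscclpp_all_reduce(comm, tensor, fallback_all_reduce):
    if comm is not None and not comm.disabled and comm.should_mscclpp_allreduce(tensor):
        return comm.all_reduce(tensor)    # out-of-place: returns a new tensor
    return fallback_all_reduce(tensor)    # e.g. the pynccl / torch.distributed path

# During CUDA graph capture, the coordinator wraps the region roughly like:
#   with pymscclpp_comm.change_state(enable=True):
#       ...capture the graph, calling maybe_mscclpp_all_reduce inside...
```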
python/sglang/srt/distributed/parallel_state.py

```diff
@@ -190,6 +190,7 @@ class GroupCoordinator:
     cpu_group: ProcessGroup  # group for CPU communication
     device_group: ProcessGroup  # group for device communication
     use_pynccl: bool  # a hint of whether to use PyNccl
+    use_pymscclpp: bool  # a hint of whether to use PyMsccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
     use_message_queue_broadcaster: (
         bool  # a hint of whether to use message queue broadcaster
@@ -205,6 +206,7 @@ class GroupCoordinator:
         local_rank: int,
         torch_distributed_backend: Union[str, Backend],
         use_pynccl: bool,
+        use_pymscclpp: bool,
         use_custom_allreduce: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
@@ -244,6 +246,7 @@ class GroupCoordinator:
             self.device = torch.device("cpu")

         self.use_pynccl = use_pynccl
+        self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
@@ -265,6 +268,17 @@ class GroupCoordinator:
                 device=self.device,
             )

+        from sglang.srt.distributed.device_communicators.pymscclpp import (
+            PyMscclppCommunicator,
+        )
+
+        self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
+        if use_pymscclpp and self.world_size > 1:
+            self.pymscclpp_comm = PyMscclppCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         self.ca_comm: Optional[CustomAllreduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
@@ -373,11 +387,15 @@ class GroupCoordinator:
         # --------------------------------------------
         # custom allreduce  | enabled | enabled |
         # PyNccl            | disabled| enabled |
+        # PyMscclpp         | disabled| enabled |
         # torch.distributed | enabled | disabled|
         #
         # Note that custom allreduce will have a runtime check, if the
         # tensor size is too large, it will fallback to the next
         # available option.
+        # Note that the PyMsccl needs to register the tensor in ahead,
+        # which will introduce large overhead in the eager case,
+        # therefore it is only supported in the graph case.
         # In summary: When using CUDA graph, we use
         # either custom all-reduce kernel or pynccl. When not using
         # CUDA graph, we use either custom all-reduce kernel or
@@ -392,7 +410,14 @@ class GroupCoordinator:
             maybe_pynccl_context = pynccl_comm.change_state(
                 enable=True, stream=torch.cuda.current_stream()
             )
-        with maybe_pynccl_context:
+        pymscclpp_comm = self.pymscclpp_comm
+        maybe_pymscclpp_context: Any
+        if not pymscclpp_comm:
+            maybe_pymscclpp_context = nullcontext()
+        else:
+            maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)
+        with maybe_pynccl_context, maybe_pymscclpp_context:
             yield graph_capture_context

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@@ -437,6 +462,10 @@ class GroupCoordinator:
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
+        ) or (
+            self.pymscclpp_comm is not None
+            and not self.pymscclpp_comm.disabled
+            and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
             return torch.ops.sglang.outplace_all_reduce(
                 input_, group_name=self.unique_name
@@ -447,9 +476,13 @@ class GroupCoordinator:
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         ca_comm = self.ca_comm
-        assert ca_comm is not None
-        assert not ca_comm.disabled
-        out = ca_comm.custom_all_reduce(input_)
+        pymscclpp_comm = self.pymscclpp_comm
+        assert ca_comm is not None or pymscclpp_comm is not None
+        if ca_comm is not None and not ca_comm.disabled:
+            out = ca_comm.custom_all_reduce(input_)
+        else:
+            assert not pymscclpp_comm.disabled
+            out = pymscclpp_comm.all_reduce(input_)
         assert out is not None
         return out
@@ -958,6 +991,7 @@ def init_world_group(
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=False,
+        use_pymscclpp=False,
         use_custom_allreduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
@@ -973,14 +1007,18 @@ def init_model_parallel_group(
     use_custom_allreduce: Optional[bool] = None,
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
+    use_mscclpp_allreduce: Optional[bool] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
+    if use_mscclpp_allreduce is None:
+        use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=not is_npu(),
+        use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
@@ -1037,6 +1075,7 @@ def graph_capture():
 logger = logging.getLogger(__name__)

 _ENABLE_CUSTOM_ALL_REDUCE = True
+_ENABLE_MSCCLPP_ALL_REDUCE = False


 def set_custom_all_reduce(enable: bool):
@@ -1044,6 +1083,11 @@ def set_custom_all_reduce(enable: bool):
     _ENABLE_CUSTOM_ALL_REDUCE = enable


+def set_mscclpp_all_reduce(enable: bool):
+    global _ENABLE_MSCCLPP_ALL_REDUCE
+    _ENABLE_MSCCLPP_ALL_REDUCE = enable
+
+
 def init_distributed_environment(
     world_size: int = -1,
     rank: int = -1,
```
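The comment table in this diff pins down the dispatch order: under CUDA graph capture the out-of-place path can be served by either the custom all-reduce kernel or mscclpp, while everything else falls back to PyNccl or torch.distributed. A condensed paraphrase of that selection logic (names follow the diff, except the in-place fallback helper, which is hypothetical):

```python
# Condensed paraphrase of GroupCoordinator.all_reduce dispatch after this diff.
def all_reduce_dispatch(self, input_):
    use_outplace = (
        self.ca_comm is not None
        and not self.ca_comm.disabled
        and self.ca_comm.should_custom_ar(input_)
    ) or (
        self.pymscclpp_comm is not None
        and not self.pymscclpp_comm.disabled
        and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
    )
    if use_outplace:
        # custom allreduce is tried first, then mscclpp; both return a new tensor
        return torch.ops.sglang.outplace_all_reduce(input_, group_name=self.unique_name)
    # otherwise the in-place paths: pynccl under graph capture, torch.distributed in eager
    return self._all_reduce_in_place(input_)  # hypothetical name for the fallback
```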
python/sglang/srt/layers/dp_attention.py

```diff
@@ -98,11 +98,12 @@ def initialize_dp_attention(
         ],
         local_rank,
         torch.distributed.get_backend(tp_group.device_group),
-        SYNC_TOKEN_IDS_ACROSS_TP,
-        False,
-        False,
-        False,
-        False,
+        use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
+        use_pymscclpp=False,
+        use_custom_allreduce=False,
+        use_hpu_communicator=False,
+        use_xpu_communicator=False,
+        use_npu_communicator=False,
         group_name="attention_tp",
     )
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -35,6 +35,7 @@ from sglang.srt.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
+    set_mscclpp_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
@@ -460,6 +461,7 @@ class ModelRunner:
         else:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
+        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

         if not self.is_draft_worker:
             # Only initialize the distributed environment on the target model worker.
```
python/sglang/srt/server_args.py

```diff
@@ -165,6 +165,7 @@ class ServerArgs:
     enable_tokenizer_batch_encode: bool = False
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
+    enable_mscclpp: bool = False
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -1168,6 +1169,11 @@
         action="store_true",
         help="Disable the custom all-reduce kernel and fall back to NCCL.",
     )
+    parser.add_argument(
+        "--enable-mscclpp",
+        action="store_true",
+        help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
+    )
     parser.add_argument(
         "--disable-overlap-schedule",
         action="store_true",
```
sgl-kernel/CMakeLists.txt

```diff
@@ -73,6 +73,14 @@ FetchContent_Declare(
   GIT_SHALLOW    OFF
 )
 FetchContent_Populate(repo-flash-attention)

+# mscclpp
+FetchContent_Declare(
+  repo-mscclpp
+  GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
+  GIT_TAG        51eca89d20f0cfb3764ccd764338d7b22cd486a6
+  GIT_SHALLOW    OFF
+)
+FetchContent_Populate(repo-mscclpp)

 # ccache option
 option(ENABLE_CCACHE "Whether to use ccache" ON)
@@ -99,6 +107,7 @@ include_directories(
   ${repo-cutlass_SOURCE_DIR}/tools/util/include
   ${repo-flashinfer_SOURCE_DIR}/include
   ${repo-flashinfer_SOURCE_DIR}/csrc
+  ${repo-mscclpp_SOURCE_DIR}/include
 )

 set(SGL_KERNEL_CUDA_FLAGS
@@ -196,6 +205,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

 set(SOURCES
+  "csrc/allreduce/mscclpp_allreduce.cu"
   "csrc/allreduce/custom_all_reduce.cu"
   "csrc/attention/cascade.cu"
   "csrc/attention/merge_attn_states.cu"
@@ -250,7 +260,27 @@ target_include_directories(common_ops PRIVATE
   ${repo-cutlass_SOURCE_DIR}/examples/common
   ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
 )

-target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
+find_package(Python3 COMPONENTS Interpreter REQUIRED)
+execute_process(
+  COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
+  OUTPUT_VARIABLE TORCH_CXX11_ABI
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+if(TORCH_CXX11_ABI STREQUAL "0")
+  message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
+  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
+else()
+  message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
+  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
+endif()
+set(MSCCLPP_USE_CUDA ON)
+set(MSCCLPP_BYPASS_GPU_CHECK ON)
+set(MSCCLPP_BUILD_TESTS OFF)
+add_subdirectory(${repo-mscclpp_SOURCE_DIR})
+target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)

 target_compile_definitions(common_ops PRIVATE
   FLASHATTENTION_DISABLE_BACKWARD
```
sgl-kernel/Makefile

```diff
@@ -19,14 +19,14 @@ submodule: ## Initialize and update git submodules
 	@git submodule update --init --recursive

 ln: submodule ## Create compilation database
-	@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES
+	@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5

 install: submodule ## Install package in development mode
 	@pip install -e . --no-build-isolation

 build: install-deps submodule ## Build and install wheel package
-	@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
+	@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps

 clean: ## Remove build artifacts
 	@rm -rf build dist *.egg-info
```
sgl-kernel/build.sh

```diff
@@ -50,6 +50,9 @@ docker run --rm \
    which cmake
    cmake --version
    yum install numactl-devel -y && \
+   yum install libibverbs -y && \
+   ln -sv /usr/lib64/libibverbs.so.1 /usr/lib64/libibverbs.so && \
    ${PYTHON_ROOT_PATH}/bin/${TORCH_INSTALL} && \
    ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv scikit-build-core && \
    export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
```
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu (new file, mode 100644)

```cuda
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <torch/library.h>

#include "mscclpp_allreduce.cuh"

enum MscclContextSelection {
  MSCCL1NODELL = 1,
  MSCCL2NODELL = 2,
};

class MscclContext {
 public:
  MscclContextSelection selection_;
  std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
  std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
  MscclContext(MscclContextSelection selection) : selection_(selection) {}

  template <typename T>
  void allreduce(
      cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
    if (selection_ == MSCCL1NODELL) {
      msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
    } else if (selection_ == MSCCL2NODELL) {
      msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
    }
  }
};

using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
  auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
  auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
  std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
  return tensor;
}

// Function to convert vector of int32_t back to array of uint8_t
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
  mscclpp::UniqueId unique_id;
  std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
  return unique_id;
}

torch::Tensor mscclpp_generate_unique_id() {
  mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
  return _unique_id2tensor(unique_id);
}

fptr_t mscclpp_init_context(
    const torch::Tensor& unique_id,
    const int64_t rank,
    const int64_t world_size,
    torch::Tensor& scratch,
    torch::Tensor& put_buffer,
    const int64_t nranks_per_node,
    const std::vector<int64_t>& rank_to_node,
    const std::vector<int64_t>& rank_to_ib,
    const int64_t context_selection) {
  MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
  mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
  if (context_selection == MSCCL1NODELL) {
    void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
    const size_t scratch_bytes = scratch.numel() * scratch.element_size();
    context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
        uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
  } else if (context_selection == MSCCL2NODELL) {
    void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
    const size_t scratch_bytes = scratch.numel() * scratch.element_size();
    void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
    const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
    context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
        uid,
        rank,
        world_size,
        scratch_ptr,
        scratch_bytes,
        put_buffer_ptr,
        put_buffer_bytes,
        nranks_per_node,
        rank_to_node,
        rank_to_ib);
  } else {
    throw std::runtime_error("invalid context selection");
  }
  return (fptr_t)context_ptr;
}

bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}

void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
  MscclContext* context = reinterpret_cast<MscclContext*>(_context);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
  TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      context->allreduce<float>(
          stream,
          reinterpret_cast<float*>(inp.data_ptr()),
          reinterpret_cast<float*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
    case at::ScalarType::Half: {
      context->allreduce<half>(
          stream,
          reinterpret_cast<half*>(inp.data_ptr()),
          reinterpret_cast<half*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      context->allreduce<__nv_bfloat16>(
          stream,
          reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
#endif
    default:
      throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
  }
}
```
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh
0 → 100644
View file @
8e3797be
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#pragma once
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/nvls_device.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
// comment this for test_mscclpp_allreduce.cu
#include "utils.h"
namespace
sglang
{
__device__
mscclpp
::
DeviceSyncer
deviceSyncer
;
__device__
mscclpp
::
DeviceSyncer
allGatherDeviceSyncer
;
__device__
mscclpp
::
DeviceSyncer
reduceScatterDeviceSyncer
;
__device__
mscclpp
::
DeviceSyncer
ibDeviceSyncer
;
template
<
typename
To
,
typename
From
>
__forceinline__
__device__
To
bit_cast
(
const
From
&
src
)
{
static_assert
(
sizeof
(
To
)
==
sizeof
(
From
),
"Size mismatch for bit_cast"
);
union
{
From
f
;
To
t
;
}
u
;
u
.
f
=
src
;
return
u
.
t
;
}
template
<
typename
T
>
__forceinline__
__device__
T
add_elements
(
T
a
,
T
b
)
{
return
a
+
b
;
}
template
<
>
__forceinline__
__device__
__half2
add_elements
(
__half2
a
,
__half2
b
)
{
return
__hadd2
(
a
,
b
);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template
<
>
__forceinline__
__device__
__nv_bfloat162
add_elements
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
return
__hadd2
(
a
,
b
);
}
#endif
template
<
typename
T
>
__forceinline__
__device__
int4
add_vectors_helper
(
int4
a
,
int4
b
)
{
int4
ret
;
ret
.
w
=
bit_cast
<
int
,
T
>
(
add_elements
(
bit_cast
<
T
,
int
>
(
a
.
w
),
bit_cast
<
T
,
int
>
(
b
.
w
)));
ret
.
x
=
bit_cast
<
int
,
T
>
(
add_elements
(
bit_cast
<
T
,
int
>
(
a
.
x
),
bit_cast
<
T
,
int
>
(
b
.
x
)));
ret
.
y
=
bit_cast
<
int
,
T
>
(
add_elements
(
bit_cast
<
T
,
int
>
(
a
.
y
),
bit_cast
<
T
,
int
>
(
b
.
y
)));
ret
.
z
=
bit_cast
<
int
,
T
>
(
add_elements
(
bit_cast
<
T
,
int
>
(
a
.
z
),
bit_cast
<
T
,
int
>
(
b
.
z
)));
return
ret
;
}
template
<
typename
T
>
__forceinline__
__device__
int4
add_vectors
(
int4
a
,
int4
b
)
{
return
add_vectors_helper
<
T
>
(
a
,
b
);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template
<
>
__forceinline__
__device__
int4
add_vectors
<
__nv_bfloat16
>
(
int4
a
,
int4
b
)
{
return
add_vectors_helper
<
__nv_bfloat162
>
(
a
,
b
);
}
#endif
template
<
>
__forceinline__
__device__
int4
add_vectors
<
__half
>
(
int4
a
,
int4
b
)
{
return
add_vectors_helper
<
__half2
>
(
a
,
b
);
}
template
<
typename
T
>
__forceinline__
__device__
uint2
add_vectors_helper
(
uint2
a
,
uint2
b
)
{
uint2
ret
;
ret
.
x
=
bit_cast
<
int
,
T
>
(
add_elements
(
bit_cast
<
T
,
int
>
(
a
.
x
),
bit_cast
<
T
,
int
>
(
b
.
x
)));
ret
.
y
=
bit_cast
<
int
,
T
>
(
add_elements
(
bit_cast
<
T
,
int
>
(
a
.
y
),
bit_cast
<
T
,
int
>
(
b
.
y
)));
return
ret
;
}
template
<
typename
T
>
__forceinline__
__device__
uint2
add_vectors
(
uint2
a
,
uint2
b
)
{
return
add_vectors_helper
<
T
>
(
a
,
b
);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template
<
>
__forceinline__
__device__
uint2
add_vectors
<
__nv_bfloat16
>
(
uint2
a
,
uint2
b
)
{
return
add_vectors_helper
<
__nv_bfloat162
>
(
a
,
b
);
}
#endif
template
<
>
__forceinline__
__device__
uint2
add_vectors
<
__half
>
(
uint2
a
,
uint2
b
)
{
return
add_vectors_helper
<
__half2
>
(
a
,
b
);
}
template
<
typename
T
>
__forceinline__
__device__
int
add_vectors_helper
(
int
a
,
int
b
)
{
return
bit_cast
<
int
,
T
>
(
add_elements
(
bit_cast
<
T
,
int
>
(
a
),
bit_cast
<
T
,
int
>
(
b
)));
}
template
<
typename
T
>
__forceinline__
__device__
int
add_vectors
(
int
a
,
int
b
)
{
return
add_vectors_helper
<
T
>
(
a
,
b
);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template
<
>
__forceinline__
__device__
int
add_vectors
<
__nv_bfloat16
>
(
int
a
,
int
b
)
{
return
add_vectors_helper
<
__nv_bfloat162
>
(
a
,
b
);
}
#endif
template
<
>
__forceinline__
__device__
int
add_vectors
<
__half
>
(
int
a
,
int
b
)
{
return
add_vectors_helper
<
__half2
>
(
a
,
b
);
}
// -------------------------------------------------------
// allreduce_LL_1node using LLPacket, origin allreduce2
// -------------------------------------------------------
__device__
uint64_t
globalFlag
=
1
;
template
<
typename
TYPE
>
__global__
void
__launch_bounds__
(
1024
,
1
)
allreduce_LL_1node
(
mscclpp
::
MemoryChannelDeviceHandle
*
memChans
,
TYPE
*
buff
,
TYPE
*
scratch
,
void
*
resultBuff
,
int
rank
,
int
worldSize
,
size_t
nelems
)
{
nelems
=
nelems
/
(
sizeof
(
int
)
/
sizeof
(
TYPE
));
// This version of allreduce only works for single nodes
const
int
nPeers
=
worldSize
-
1
;
const
size_t
nPkts
=
nelems
/
2
;
const
int
nelemsPerRank
=
nelems
/
worldSize
;
const
int
nPktsPerRank
=
nelemsPerRank
/
2
;
// flag for packets. Initially 1
const
uint32_t
flag
=
(
uint32_t
)
globalFlag
;
// thread block & channel info
const
int
nBlocksPerPeer
=
gridDim
.
x
/
nPeers
;
const
int
localBlockIdx
=
blockIdx
.
x
%
nBlocksPerPeer
;
const
int
peerIdx
=
blockIdx
.
x
/
nBlocksPerPeer
;
const
int
remoteRank
=
peerIdx
<
rank
?
peerIdx
:
peerIdx
+
1
;
mscclpp
::
MemoryChannelDeviceHandle
memChan
=
memChans
[
peerIdx
];
const
int
tid
=
threadIdx
.
x
+
localBlockIdx
*
blockDim
.
x
;
// double buffering
size_t
scratchBaseOffset
=
(
flag
&
1
)
?
0
:
nPkts
*
sizeof
(
mscclpp
::
LLPacket
);
void
*
scratchBuff
=
(
void
*
)((
char
*
)
scratch
+
scratchBaseOffset
);
size_t
scratchOffset
=
scratchBaseOffset
+
rank
*
nPktsPerRank
*
sizeof
(
mscclpp
::
LLPacket
);
size_t
scratchResultOffset
=
(
flag
&
1
)
?
2
*
nPkts
*
sizeof
(
mscclpp
::
LLPacket
)
:
3
*
nPkts
*
sizeof
(
mscclpp
::
LLPacket
);
size_t
srcOffset
=
remoteRank
*
nelemsPerRank
*
sizeof
(
int
);
uint2
*
src
=
(
uint2
*
)((
char
*
)
buff
+
rank
*
nelemsPerRank
*
sizeof
(
int
));
uint2
*
dst
=
(
uint2
*
)((
char
*
)
resultBuff
+
rank
*
nelemsPerRank
*
sizeof
(
int
));
// step 1: write to scratch buffer
memChan
.
putPackets
(
scratchOffset
,
srcOffset
,
nelemsPerRank
*
sizeof
(
int
),
tid
,
blockDim
.
x
*
nBlocksPerPeer
,
flag
);
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
for
(
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
idx
<
nPktsPerRank
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
uint2
data
=
make_uint2
(
0
,
0
);
for
(
int
index
=
0
;
index
<
nPeers
;
index
++
)
{
const
int
remoteRank
=
index
<
rank
?
index
:
index
+
1
;
mscclpp
::
LLPacket
*
dstPkt
=
(
mscclpp
::
LLPacket
*
)
scratchBuff
+
remoteRank
*
nPktsPerRank
;
uint2
val
=
dstPkt
[
idx
].
read
(
flag
);
data
=
add_vectors
<
TYPE
>
(
val
,
data
);
}
data
=
add_vectors
<
TYPE
>
(
data
,
src
[
idx
]);
dst
[
idx
]
=
data
;
mscclpp
::
LLPacket
packet
;
packet
.
data1
=
data
.
x
;
packet
.
flag1
=
flag
;
packet
.
data2
=
data
.
y
;
packet
.
flag2
=
flag
;
size_t
offset
=
scratchResultOffset
/
sizeof
(
mscclpp
::
LLPacket
)
+
(
idx
+
rank
*
nPktsPerRank
);
for
(
int
index
=
0
;
index
<
nPeers
;
index
++
)
{
memChans
[
index
].
write
(
offset
,
packet
);
}
}
// step 3: get data result from scratch buffer
mscclpp
::
LLPacket
*
dstPkt
=
(
mscclpp
::
LLPacket
*
)((
char
*
)
scratch
+
scratchResultOffset
);
const
int
dstOffset
=
remoteRank
*
nPktsPerRank
;
uint2
*
result
=
(
uint2
*
)((
char
*
)
resultBuff
+
remoteRank
*
nelemsPerRank
*
sizeof
(
int
));
for
(
int
idx
=
threadIdx
.
x
+
localBlockIdx
*
blockDim
.
x
;
idx
<
nPktsPerRank
;
idx
+=
blockDim
.
x
*
nBlocksPerPeer
)
{
uint2
data
=
dstPkt
[
idx
+
dstOffset
].
read
(
flag
);
result
[
idx
].
x
=
data
.
x
;
result
[
idx
].
y
=
data
.
y
;
}
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
)
{
globalFlag
+=
1
;
}
}
// -------------------------------------------------------
// allreduce_LL_2node using LLPacket, origin allreduce5
// -------------------------------------------------------
template
<
typename
TYPE
>
__global__
void
__launch_bounds__
(
1024
,
1
)
allreduce_LL_2node
(
mscclpp
::
MemoryChannelDeviceHandle
*
memChans
,
mscclpp
::
PortChannelDeviceHandle
*
portChans
,
TYPE
*
buff
,
TYPE
*
scratch
,
TYPE
*
putBuff
,
TYPE
*
resultBuff
,
int
rank
,
int
nRanksPerNode
,
int
worldSize
,
size_t
nelems
)
{
nelems
=
nelems
/
(
sizeof
(
int
)
/
sizeof
(
TYPE
));
// This version of allreduce only works for single nodes
const
int
nPeersInNode
=
nRanksPerNode
-
1
;
const
int
nPkts
=
nelems
/
2
;
const
int
nelemsPerLocalRank
=
nelems
/
nRanksPerNode
;
const
int
nPktsPerLocalRank
=
nelemsPerLocalRank
/
2
;
const
int
localRankId
=
rank
%
nRanksPerNode
;
// flag for packets. Initially 1
const
uint32_t
flag
=
(
uint32_t
)
globalFlag
;
// thread block & channel info
const
int
nBlocksPerPeer
=
gridDim
.
x
/
nPeersInNode
;
const
int
localBlockIdx
=
blockIdx
.
x
%
nBlocksPerPeer
;
const
int
peerIdx
=
blockIdx
.
x
/
nBlocksPerPeer
;
const
int
remoteRankIdx
=
peerIdx
<
localRankId
?
peerIdx
:
peerIdx
+
1
;
mscclpp
::
MemoryChannelDeviceHandle
memChan
=
memChans
[
peerIdx
];
mscclpp
::
PortChannelDeviceHandle
portChan
=
portChans
[
localRankId
];
const
int
tid
=
threadIdx
.
x
+
localBlockIdx
*
blockDim
.
x
;
// double buffering
size_t
scratchBaseOffset
=
(
flag
&
1
)
?
0
:
nPkts
*
sizeof
(
mscclpp
::
LLPacket
);
size_t
putBaseOffset
=
(
flag
&
1
)
?
0
:
nPktsPerLocalRank
*
sizeof
(
mscclpp
::
LLPacket
);
void
*
scratchBuff
=
(
void
*
)((
char
*
)
scratch
+
scratchBaseOffset
);
size_t
scratchOffset
=
scratchBaseOffset
+
localRankId
*
nPktsPerLocalRank
*
sizeof
(
mscclpp
::
LLPacket
);
size_t
scratchResultOffset
=
(
flag
&
1
)
?
2
*
nPkts
*
sizeof
(
mscclpp
::
LLPacket
)
:
3
*
nPkts
*
sizeof
(
mscclpp
::
LLPacket
);
size_t
srcOffset
=
remoteRankIdx
*
nelemsPerLocalRank
*
sizeof
(
int
);
uint2
*
src
=
(
uint2
*
)((
char
*
)
buff
+
localRankId
*
nelemsPerLocalRank
*
sizeof
(
int
));
uint2
*
dst
=
(
uint2
*
)((
char
*
)
resultBuff
+
localRankId
*
nelemsPerLocalRank
*
sizeof
(
int
));
// step 1: write to scratch buffer
if
(
nRanksPerNode
>
1
)
{
memChan
.
putPackets
(
scratchOffset
,
srcOffset
,
nelemsPerLocalRank
*
sizeof
(
int
),
tid
,
blockDim
.
x
*
nBlocksPerPeer
,
flag
);
}
// step 2: get data from scratch buffer, do local reduce-scatter in each node.
mscclpp
::
LLPacket
*
putPkt
=
(
mscclpp
::
LLPacket
*
)((
char
*
)
putBuff
+
putBaseOffset
);
for
(
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
idx
<
nPktsPerLocalRank
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
uint2
data
=
make_uint2
(
0
,
0
);
for
(
int
index
=
0
;
index
<
nPeersInNode
;
index
++
)
{
const
int
remoteRank
=
index
<
localRankId
?
index
:
index
+
1
;
mscclpp
::
LLPacket
*
dstPkt
=
(
mscclpp
::
LLPacket
*
)
scratchBuff
+
remoteRank
*
nPktsPerLocalRank
;
uint2
val
=
dstPkt
[
idx
].
read
(
flag
);
data
=
add_vectors
<
TYPE
>
(
val
,
data
);
}
data
=
add_vectors
<
TYPE
>
(
data
,
src
[
idx
]);
putPkt
[
idx
].
write
(
data
.
x
,
data
.
y
,
flag
);
dst
[
idx
]
=
data
;
}
deviceSyncer
.
sync
(
gridDim
.
x
);
// step 3. send local reduced data to remote node.
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
)
{
portChan
.
put
(
scratchOffset
,
putBaseOffset
,
nPktsPerLocalRank
*
sizeof
(
mscclpp
::
LLPacket
));
if
((
flag
&
63
)
==
0
)
{
portChan
.
flush
();
}
}
// step 4. try to read the data from scratch buffer and write to local peers
mscclpp
::
LLPacket
*
dstPkt
=
(
mscclpp
::
LLPacket
*
)
scratchBuff
+
localRankId
*
nPktsPerLocalRank
;
for
(
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
idx
<
nPktsPerLocalRank
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
uint2
res
=
dst
[
idx
];
uint2
val
=
dstPkt
[
idx
].
read
(
flag
);
res
=
add_vectors
<
TYPE
>
(
res
,
val
);
mscclpp
::
LLPacket
packet
;
packet
.
data1
=
res
.
x
;
packet
.
flag1
=
flag
;
packet
.
data2
=
res
.
y
;
packet
.
flag2
=
flag
;
size_t
offset
=
scratchResultOffset
/
sizeof
(
mscclpp
::
LLPacket
)
+
(
idx
+
localRankId
*
nPktsPerLocalRank
);
for
(
int
index
=
0
;
index
<
nPeersInNode
;
index
++
)
{
memChans
[
index
].
write
(
offset
,
packet
);
}
dst
[
idx
]
=
res
;
}
// step 5: get data result from scratch buffer
dstPkt
=
(
mscclpp
::
LLPacket
*
)((
char
*
)
scratch
+
scratchResultOffset
);
const
int
dstOffset
=
remoteRankIdx
*
nPktsPerLocalRank
;
uint2
*
result
=
(
uint2
*
)((
char
*
)
resultBuff
+
remoteRankIdx
*
nelemsPerLocalRank
*
sizeof
(
int
));
if
(
nRanksPerNode
>
1
)
{
for
(
int
idx
=
threadIdx
.
x
+
localBlockIdx
*
blockDim
.
x
;
idx
<
nPktsPerLocalRank
;
idx
+=
blockDim
.
x
*
nBlocksPerPeer
)
{
uint2
data
=
dstPkt
[
idx
+
dstOffset
].
read
(
flag
);
result
[
idx
]
=
data
;
}
}
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
)
{
globalFlag
+=
1
;
}
}
static
const
mscclpp
::
Transport
IBs
[]
=
{
mscclpp
::
Transport
::
IB0
,
mscclpp
::
Transport
::
IB1
,
mscclpp
::
Transport
::
IB2
,
mscclpp
::
Transport
::
IB3
,
mscclpp
::
Transport
::
IB4
,
mscclpp
::
Transport
::
IB5
,
mscclpp
::
Transport
::
IB6
,
mscclpp
::
Transport
::
IB7
};
class
MscclCommGroup
{
public:
std
::
shared_ptr
<
mscclpp
::
Communicator
>
comm_
;
const
size_t
rank_
;
const
size_t
world_size_
;
const
std
::
vector
<
int64_t
>
rank_to_node_
;
const
std
::
vector
<
int64_t
>
rank_to_ib_
;
MscclCommGroup
(
mscclpp
::
UniqueId
unique_id
,
const
size_t
rank
,
const
size_t
world_size
,
const
std
::
vector
<
int64_t
>&
rank_to_node
,
const
std
::
vector
<
int64_t
>&
rank_to_ib
)
:
rank_
(
rank
),
world_size_
(
world_size
),
rank_to_node_
(
rank_to_node
),
rank_to_ib_
(
rank_to_ib
)
{
auto
bootstrap
=
std
::
make_shared
<
mscclpp
::
TcpBootstrap
>
(
rank
,
world_size
);
bootstrap
->
initialize
(
unique_id
);
comm_
=
std
::
make_shared
<
mscclpp
::
Communicator
>
(
bootstrap
);
}
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
output
,
size_t
input_numel
,
int
threads
=
512
,
int
block_limit
=
21
)
{
throw
std
::
runtime_error
(
"you should not call allreduce of a base context"
);
}
bool
is_same_node
(
int
r1
,
int
r2
)
{
return
rank_to_node_
[
r1
]
==
rank_to_node_
[
r2
];
}
void
make_connection
(
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
mscclpp
::
Connection
>>&
same_node_connections
,
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
mscclpp
::
Connection
>>&
cross_node_connections
)
{
same_node_connections
.
clear
();
cross_node_connections
.
clear
();
std
::
unordered_map
<
int
,
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>>
conn_futures
;
for
(
int
r
=
0
;
r
<
world_size_
;
++
r
)
{
if
(
r
==
rank_
)
continue
;
mscclpp
::
Transport
transport
=
is_same_node
(
r
,
rank_
)
?
mscclpp
::
Transport
::
CudaIpc
:
IBs
[
rank_to_ib_
[
r
]];
conn_futures
.
emplace
(
r
,
comm_
->
connectOnSetup
(
r
,
0
,
transport
));
}
comm_
->
setup
();
for
(
int
r
=
0
;
r
<
world_size_
;
++
r
)
{
if
(
r
==
rank_
)
continue
;
if
(
is_same_node
(
r
,
rank_
))
{
same_node_connections
.
emplace
(
r
,
conn_futures
[
r
].
get
());
}
else
{
cross_node_connections
.
emplace
(
r
,
conn_futures
[
r
].
get
());
}
}
}
void
make_memory_channels_with_scratch
(
void
*
tensor_ptr
,
const
size_t
tensor_bytes
,
void
*
scratch_ptr
,
const
size_t
scratch_bytes
,
const
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
mscclpp
::
Connection
>>&
connections
,
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
mscclpp
::
MemoryDevice2DeviceSemaphore
>>&
semaphores
,
std
::
unordered_map
<
int
,
mscclpp
::
RegisteredMemory
>&
registered_memories
,
std
::
unordered_map
<
int
,
mscclpp
::
MemoryChannel
>&
channels
)
{
channels
.
clear
();
make_semaphores
<
mscclpp
::
MemoryDevice2DeviceSemaphore
>
(
connections
,
semaphores
);
register_tensor_with_connections
(
scratch_ptr
,
scratch_bytes
,
connections
,
registered_memories
);
for
(
const
auto
&
[
peer
,
_
]
:
connections
)
{
channels
.
emplace
(
peer
,
mscclpp
::
MemoryChannel
(
semaphores
[
peer
],
registered_memories
[
peer
],
tensor_ptr
,
scratch_ptr
));
}
}
void
make_port_channels_with_scratch
(
std
::
shared_ptr
<
mscclpp
::
ProxyService
>
proxyService
,
void
*
tensor_ptr
,
const
size_t
tensor_bytes
,
void
*
scratch_ptr
,
const
size_t
scratch_bytes
,
const
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
mscclpp
::
Connection
>>&
connections
,
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
mscclpp
::
Host2DeviceSemaphore
>>&
semaphores
,
std
::
unordered_map
<
int
,
mscclpp
::
RegisteredMemory
>&
registered_memories
,
std
::
unordered_map
<
int
,
mscclpp
::
PortChannel
>&
channels
)
{
channels
.
clear
();
make_semaphores
<
mscclpp
::
Host2DeviceSemaphore
>
(
connections
,
semaphores
);
mscclpp
::
TransportFlags
flags
;
for
(
const
auto
&
[
_
,
conn
]
:
connections
)
{
flags
|=
conn
->
transport
();
}
auto
local_reg_memory
=
comm_
->
registerMemory
(
tensor_ptr
,
tensor_bytes
,
flags
);
register_tensor_with_connections
(
scratch_ptr
,
scratch_bytes
,
connections
,
registered_memories
);
std
::
unordered_map
<
int
,
mscclpp
::
SemaphoreId
>
semaphore_ids
;
std
::
unordered_map
<
int
,
size_t
>
memory_ids
;
memory_ids
[
rank_
]
=
proxyService
->
addMemory
(
local_reg_memory
);
for
(
const
auto
&
[
peer
,
memory
]
:
registered_memories
)
{
if
(
peer
==
rank_
)
continue
;
memory_ids
[
peer
]
=
proxyService
->
addMemory
(
memory
);
}
for
(
const
auto
&
[
peer
,
semaphore
]
:
semaphores
)
{
semaphore_ids
[
peer
]
=
proxyService
->
addSemaphore
(
semaphore
);
}
for
(
const
auto
&
[
peer
,
_
]
:
connections
)
{
channels
.
emplace
(
peer
,
proxyService
->
portChannel
(
semaphore_ids
[
peer
],
memory_ids
[
peer
],
memory_ids
[
rank_
]));
}
}
template
<
typename
SemaphoreType
>
void
make_semaphores
(
const
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
mscclpp
::
Connection
>>&
connections
,
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
SemaphoreType
>>&
semaphores
)
{
semaphores
.
clear
();
for
(
const
auto
&
[
peer
,
conn
]
:
connections
)
{
semaphores
[
peer
]
=
std
::
make_shared
<
SemaphoreType
>
(
*
comm_
,
conn
);
}
comm_
->
setup
();
}
void
register_tensor_with_connections
(
void
*
tensor_ptr
,
size_t
tensor_bytes
,
const
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
mscclpp
::
Connection
>>&
connections
,
std
::
unordered_map
<
int
,
mscclpp
::
RegisteredMemory
>&
registered_memories
)
{
registered_memories
.
clear
();
mscclpp
::
TransportFlags
all_transports
;
for
(
const
auto
&
[
_
,
connection
]
:
connections
)
{
all_transports
|=
connection
->
transport
();
}
mscclpp
::
RegisteredMemory
buf_reg_mem
=
comm_
->
registerMemory
(
tensor_ptr
,
tensor_bytes
,
all_transports
);
registered_memories
[
rank_
]
=
buf_reg_mem
;
std
::
unordered_map
<
int
,
mscclpp
::
NonblockingFuture
<
mscclpp
::
RegisteredMemory
>>
remote_mem_futures
;
for
(
const
auto
&
[
r
,
connection
]
:
connections
)
{
comm_
->
sendMemoryOnSetup
(
buf_reg_mem
,
r
,
0
);
auto
remoteMemory
=
comm_
->
recvMemoryOnSetup
(
r
,
0
);
remote_mem_futures
.
emplace
(
r
,
remoteMemory
);
}
comm_
->
setup
();
for
(
auto
&
[
r
,
mem_feature
]
:
remote_mem_futures
)
{
registered_memories
.
emplace
(
r
,
mem_feature
.
get
());
}
}
  void make_device_memory_handle_base_on_new_ptr(
      const std::unordered_map<int, mscclpp::MemoryChannel>& old_memory_channels,
      std::unordered_map<int, mscclpp::RegisteredMemory>& registered_sm_memories,
      std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memory_semaphores,
      std::unordered_map<int, mscclpp::MemoryChannel>& memory_channels,
      mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>& device_memory_handle,
      void* input,
      void* scratch,
      const cudaStream_t stream) {
    memory_channels.clear();
    for (const auto& [peer, channel] : old_memory_channels) {
      memory_channels.emplace(
          peer, mscclpp::MemoryChannel(memory_semaphores[peer], registered_sm_memories[peer], input, scratch));
    }
    std::vector<mscclpp::MemoryChannel> memory_channels_list;
    for (int r = 0; r < world_size_; r++) {
      if (r == rank_) continue;
      if (is_same_node(r, rank_)) {
        memory_channels_list.push_back(memory_channels[r]);
      }
    }
    std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
    std::transform(
        memory_channels_list.begin(),
        memory_channels_list.end(),
        memory_channel_handlers.begin(),
        [](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
    mscclpp::gpuMemcpyAsync<mscclpp::MemoryChannelDeviceHandle>(
        device_memory_handle.data(),
        memory_channel_handlers.data(),
        memory_channel_handlers.size(),
        stream,
        cudaMemcpyHostToDevice);
  }
};
class Msccl1NodeLLcontext {
 private:
  std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
  void* scratch_;
  const size_t scratch_bytes_;
  std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
  std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
  std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
  std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
  std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
  mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
  std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
  std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
  cudaStream_t h2d_stream;
  const size_t nranks_per_node_;

 public:
  Msccl1NodeLLcontext(
      mscclpp::UniqueId unique_id,
      const size_t rank,
      const size_t world_size,
      void* scratch,
      const size_t scratch_bytes,
      const size_t nranks_per_node,
      const std::vector<int64_t>& rank_to_node,
      const std::vector<int64_t>& rank_to_ib)
      : scratch_(scratch),
        scratch_bytes_(scratch_bytes),
        nranks_per_node_(nranks_per_node),
        d_memHandles_(nranks_per_node - 1) {
    CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
    comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
    comm_group_->make_connection(same_node_connections_, cross_node_connections_);
    comm_group_->make_memory_channels_with_scratch(
        scratch_,
        scratch_bytes_,
        scratch_,
        scratch_bytes_,
        same_node_connections_,
        memory_semaphores_,
        registered_sm_memories_,
        memory_channels_);
    // Flatten the per-peer memory channels into a device-side handle array once at setup time.
    std::vector<mscclpp::MemoryChannel> memory_channels_list;
    for (int r = 0; r < comm_group_->world_size_; r++) {
      if (r == comm_group_->rank_) continue;
      memory_channels_list.push_back(memory_channels_[r]);
    }
    std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
    std::transform(
        memory_channels_list.begin(),
        memory_channels_list.end(),
        memory_channel_handlers.begin(),
        [](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
    mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
        d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
  }

  ~Msccl1NodeLLcontext() {
    CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
  }

  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, size_t input_numel, int nthreads = 512, int nblocks = 21) {
    dim3 nthrs(nthreads);
    dim3 nblks(nblocks);
    cudaStreamCaptureStatus capturing_status;
    CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
    mscclpp::MemoryChannelDeviceHandle* memChans;
    if (capturing_status != cudaStreamCaptureStatusActive) {
      // Eager path: rebind the channels to the current input pointer on the side stream and wait.
      std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
      comm_group_->make_device_memory_handle_base_on_new_ptr(
          memory_channels_,
          registered_sm_memories_,
          memory_semaphores_,
          memory_channels,
          d_memHandles_,
          input,
          scratch_,
          h2d_stream);
      CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
      memChans = d_memHandles_.data();
    } else {
      // Graph-capture path: build the handles once per input pointer and cache them so that
      // graph replays keep reading from the same device handle buffer.
      void* input_void_ptr = reinterpret_cast<void*>(input);
      if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
        std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
        mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(comm_group_->world_size_ - 1);
        comm_group_->make_device_memory_handle_base_on_new_ptr(
            memory_channels_,
            registered_sm_memories_,
            memory_semaphores_,
            memory_channels,
            device_memory_handle,
            input,
            scratch_,
            h2d_stream);
        input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
        input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
      }
      auto it = input_ptr2d_memHandles_.find(input_void_ptr);
      memChans = it->second.data();
    }
    allreduce_LL_1node<T><<<nblks, nthrs, 0, stream>>>(
        memChans, (T*)input, (T*)scratch_, output, comm_group_->rank_, comm_group_->world_size_, input_numel);
    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
      printf("rank: %lu failed to launch allreduce_LL_1node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
    }
  }
};
class Msccl2NodeLLcontext {
 private:
  std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
  void* scratch_;
  const size_t scratch_bytes_;
  void* put_buffer_;
  const size_t put_buffer_bytes_;
  std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
  std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
  std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
  std::unordered_map<int, mscclpp::RegisteredMemory> registered_port_memories_;
  std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
  std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>> port_semaphores_;
  std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
  std::unordered_map<int, mscclpp::PortChannel> port_channels_;
  mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
  mscclpp::GpuBuffer<mscclpp::PortChannelDeviceHandle> d_portHandles_;
  std::shared_ptr<mscclpp::ProxyService> proxyService;
  cudaStream_t h2d_stream;
  const size_t nranks_per_node_;
  std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
  std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;

 public:
  Msccl2NodeLLcontext(
      mscclpp::UniqueId unique_id,
      const size_t rank,
      const size_t world_size,
      void* scratch,
      const size_t scratch_bytes,
      void* put_buffer,
      const size_t put_buffer_bytes,
      const size_t nranks_per_node,
      const std::vector<int64_t>& rank_to_node,
      const std::vector<int64_t>& rank_to_ib)
      : scratch_(scratch),
        scratch_bytes_(scratch_bytes),
        put_buffer_(put_buffer),
        put_buffer_bytes_(put_buffer_bytes),
        nranks_per_node_(nranks_per_node),
        d_memHandles_(nranks_per_node - 1),
        d_portHandles_(world_size - nranks_per_node) {
    CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
    comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
    proxyService = std::make_shared<mscclpp::ProxyService>();
    proxyService->startProxy();
    comm_group_->make_connection(same_node_connections_, cross_node_connections_);
    // Memory channels cover same-node peers; port (proxy) channels cover cross-node peers.
    comm_group_->make_memory_channels_with_scratch(
        scratch_,
        scratch_bytes_,
        scratch_,
        scratch_bytes_,
        same_node_connections_,
        memory_semaphores_,
        registered_sm_memories_,
        memory_channels_);
    comm_group_->make_port_channels_with_scratch(
        proxyService,
        put_buffer_,
        put_buffer_bytes_,
        scratch_,
        scratch_bytes_,
        cross_node_connections_,
        port_semaphores_,
        registered_port_memories_,
        port_channels_);
    std::vector<mscclpp::MemoryChannel> memory_channels_list;
    std::vector<mscclpp::PortChannel> port_channels_list;
    for (int r = 0; r < comm_group_->world_size_; r++) {
      if (r == comm_group_->rank_) continue;
      if (comm_group_->is_same_node(r, comm_group_->rank_)) {
        memory_channels_list.push_back(memory_channels_[r]);
      } else {
        port_channels_list.push_back(port_channels_[r]);
      }
    }
    std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
    std::transform(
        memory_channels_list.begin(),
        memory_channels_list.end(),
        memory_channel_handlers.begin(),
        [](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
    mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
        d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
    std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handlers(port_channels_list.size());
    std::transform(
        port_channels_list.begin(),
        port_channels_list.end(),
        port_channel_handlers.begin(),
        [](const mscclpp::PortChannel& channel) { return channel.deviceHandle(); });
    mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
        d_portHandles_.data(), port_channel_handlers.data(), port_channel_handlers.size(), cudaMemcpyHostToDevice);
  }

  ~Msccl2NodeLLcontext() {
    CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
    if (proxyService) {
      proxyService->stopProxy();
    }
  }

  template <typename T>
  void allreduce(
      cudaStream_t stream, T* input, T* output, const size_t input_numel, int nthreads = 512, int nblocks = 21) {
    dim3 nthrs(nthreads);
    dim3 nblks(nblocks);
    cudaStreamCaptureStatus capturing_status;
    CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
    mscclpp::MemoryChannelDeviceHandle* memChans;
    if (capturing_status != cudaStreamCaptureStatusActive) {
      // Eager path: rebind the memory channels to the current input pointer and wait for the copy.
      std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
      comm_group_->make_device_memory_handle_base_on_new_ptr(
          memory_channels_,
          registered_sm_memories_,
          memory_semaphores_,
          memory_channels,
          d_memHandles_,
          input,
          scratch_,
          h2d_stream);
      CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
      memChans = d_memHandles_.data();
    } else {
      // Graph-capture path: cache one handle buffer per distinct input pointer.
      void* input_void_ptr = reinterpret_cast<void*>(input);
      if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
        std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
        mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(7);
        comm_group_->make_device_memory_handle_base_on_new_ptr(
            memory_channels_,
            registered_sm_memories_,
            memory_semaphores_,
            memory_channels,
            device_memory_handle,
            input,
            scratch_,
            h2d_stream);
        input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
        input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
      }
      auto it = input_ptr2d_memHandles_.find(input_void_ptr);
      memChans = it->second.data();
    }
    allreduce_LL_2node<T><<<nblks, nthrs, 0, stream>>>(
        memChans,
        d_portHandles_.data(),
        (T*)input,
        (T*)scratch_,
        (T*)put_buffer_,
        output,
        comm_group_->rank_,
        nranks_per_node_,
        comm_group_->world_size_,
        input_numel);
    cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
      printf("rank: %lu failed to launch allreduce_LL_2node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
    }
  }
};
}  // namespace sglang
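A note on the `allreduce()` methods above: both contexts check `cudaStreamIsCapturing` and take one of two paths. Outside CUDA graph capture they rebuild the per-peer device handles for the current input pointer on `h2d_stream` and synchronize before every launch; during capture they build the handles once per input pointer and cache them in `input_ptr2d_memHandles_`, so a replayed graph must keep reusing the same input buffer. The following is a hedged sketch of what that implies for a caller driving the raw op under graph capture; the function names follow the Python wrappers introduced later in this commit, and creation of `ctx` (via `mscclpp_init_context`) plus the distributed setup are assumed to have happened as in the tests below.

# Hedged sketch: a replayed graph reuses the cached per-input-pointer handles,
# so new data must be written into the *same* `inp` tensor between replays.
# Assumes `ctx` was created with custom_ops.mscclpp_init_context(...) on every rank.
import torch
import sgl_kernel.allreduce as custom_ops

inp = torch.zeros(4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(inp)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):  # capture: handles for `inp` are built once and cached
    custom_ops.mscclpp_allreduce(ctx, inp, out, nthreads=512, nblocks=21)

for step in range(3):
    inp.fill_(float(step))  # refill the buffer the graph was captured with
    g.replay()              # re-runs the one-shot all-reduce on the cached handles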
sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu
0 → 100644
View file @
8e3797be
/*
* this file is used to test mscclpp_allreduce.cu using mpirun
* this file is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu
usage:
cd PATH-TO-THIS-FILE
export MPI_HOME=/usr/local/mpi
# export MPI_HOME=/opt/hpcx/ompi/
export MSCCLPP_HOME=/workspace/test/mscclpp
nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \
-o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \
-I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \
-lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi
/opt/hpcx/ompi/bin/mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \
--map-by ppr:8:node \
--mca btl_openib_warn_no_device_params_found 0 \
--mca btl_tcp_if_include bond0 \
--allow-run-as-root -np 8 \
-x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \
-x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce
*/
#include <mpi.h>
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#ifndef CHECK_CUDA_SUCCESS
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif
#include <cstdint>
#include "mscclpp_allreduce.cuh"
template <typename T>
bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) {
  return fabs(a - b) <= (atol + rtol * fabs(b));
}
int main(int argc, char* argv[]) {
  // init mpi
  MPI_Init(&argc, &argv);
  printf("MPI Initialized.\n");
  int nranks, rank;
  // get world size and rank id
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  cudaSetDevice(rank);
  printf("nranks: %d, rank: %d\n", nranks, rank);
  // init host and device buffers
  using T = float;
  using ReduceT = float;
  const size_t num_elems = 2 * 1024 * 1024;
  std::vector<T> host_buf(num_elems);
  for (uint32_t i = 0; i < num_elems; ++i) {
    host_buf[i] = T(i + rank);
  }
  thrust::device_vector<T> device_buf(host_buf);
  const size_t buf_size_in_bytes = num_elems * sizeof(T);
  std::vector<T> host_result_buf(num_elems);
  thrust::device_vector<T> device_result_buf(host_result_buf);
  std::vector<T> host_scratch_buf(num_elems * 8);
  for (uint32_t i = 0; i < num_elems; ++i) {
    host_scratch_buf[i] = 1;
  }
  thrust::device_vector<T> device_scratch_buf(host_scratch_buf);
  std::vector<T> host_put_buf(num_elems);
  thrust::device_vector<T> device_put_buf(host_put_buf);
  mscclpp::UniqueId unique_id;
  if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId();
  MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD);
  std::vector<int64_t> rank_to_node(nranks);
  std::vector<int64_t> rank_to_ib(nranks);
  for (int i = 0; i < nranks; i++) {
    rank_to_node[i] = i / 8;
    rank_to_ib[i] = i % 8;
  }
  cudaStream_t s;
  CHECK_CUDA_SUCCESS(cudaStreamCreate(&s));
  CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
  if (nranks == 8) {
    auto context = std::make_shared<sglang::Msccl1NodeLLcontext>(
        unique_id,
        rank,
        nranks,
        thrust::raw_pointer_cast(device_scratch_buf.data()),
        buf_size_in_bytes * 8,
        /*nranks_per_node=*/8,
        rank_to_node,
        rank_to_ib);
    printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank);
    MPI_Barrier(MPI_COMM_WORLD);
    context->allreduce<T>(
        s,
        thrust::raw_pointer_cast(device_buf.data()),
        thrust::raw_pointer_cast(device_result_buf.data()),
        device_buf.size());
  } else if (nranks == 16) {
    // TODO: this branch is untested since there is something wrong with mpirun on my test machine
    auto context = std::make_shared<sglang::Msccl2NodeLLcontext>(
        unique_id,
        rank,
        nranks,
        thrust::raw_pointer_cast(device_scratch_buf.data()),
        buf_size_in_bytes * 8,
        thrust::raw_pointer_cast(device_put_buf.data()),
        buf_size_in_bytes,
        /*nranks_per_node=*/8,
        rank_to_node,
        rank_to_ib);
    printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank);
    MPI_Barrier(MPI_COMM_WORLD);
    context->allreduce<T>(
        s,
        thrust::raw_pointer_cast(device_buf.data()),
        thrust::raw_pointer_cast(device_result_buf.data()),
        device_buf.size());
  }
  // check result correctness
  thrust::host_vector<T> host_buf_result = device_result_buf;
  size_t num_results_error_atol_1e_3_rtol_1e_3 = 0;
  bool nan_detected = false;
  for (uint32_t i = 0; i < num_elems; ++i) {
    T expected = T(i * nranks + (nranks - 1) * nranks / 2);
    if (std::isnan(float(host_buf_result[i]))) {
      nan_detected = true;
    }
    if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) {
      num_results_error_atol_1e_3_rtol_1e_3++;
    }
  }
  float result_accuracy = 1. - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems);
  printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy);
  CHECK_CUDA_SUCCESS(cudaStreamDestroy(s));
  MPI_Finalize();
  return 0;
}
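The correctness check above relies on each rank filling element i with i + rank, so the all-reduced value is the sum over all ranks of (i + rank), which collapses to the closed form i * nranks + nranks * (nranks - 1) / 2 used in the loop. A quick, CPU-only way to sanity-check that closed form:

# Verify the expected-value formula used by the CUDA test above:
# sum_{rank=0}^{nranks-1} (i + rank) == i * nranks + nranks * (nranks - 1) / 2.
def expected(i: int, nranks: int) -> int:
    return i * nranks + nranks * (nranks - 1) // 2

for nranks in (8, 16):
    for i in (0, 1, 17, 2 * 1024 * 1024 - 1):
        assert expected(i, nranks) == sum(i + rank for rank in range(nranks))
print("expected-value formula checks out")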
sgl-kernel/csrc/common_extension.cc
View file @
8e3797be
...
...
@@ -38,6 +38,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()"
);
  m.impl("all_reduce", torch::kCUDA, &all_reduce);
  m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id);
  m.def(
      "mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, "
      "int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int");
  m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context);
  m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
  m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
/*
* From csrc/attention
*/
...
...
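Each def/impl pair registered in the TORCH_LIBRARY_FRAGMENT above surfaces in Python under the torch.ops namespace, which is exactly how the sgl_kernel wrappers below reach it. A hedged sketch of that mapping (it requires the compiled sgl_kernel extension to be importable; no GPU is needed just to generate a bootstrap id):

# Hedged sketch: importing sgl_kernel loads the extension and its op registrations,
# after which each op is reachable as torch.ops.sgl_kernel.<name>.default(...).
import torch
import sgl_kernel  # noqa: F401  (loads the compiled extension)

unique_id = torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()
print(unique_id.shape, unique_id.dtype)  # a small tensor encoding the MSCCL++ bootstrap id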
sgl-kernel/include/sgl_kernel_ops.h
View file @
8e3797be
...
...
@@ -74,6 +74,18 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor mscclpp_generate_unique_id();
fptr_t mscclpp_init_context(
    const torch::Tensor& unique_id,
    const int64_t rank,
    const int64_t world_size,
    torch::Tensor& scratch,
    torch::Tensor& put_buffer,
    const int64_t nranks_per_node,
    const std::vector<int64_t>& rank_to_node,
    const std::vector<int64_t>& rank_to_ib,
    const int64_t context_selection);
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks);
#endif
/*
...
...
sgl-kernel/python/sgl_kernel/allreduce.py
View file @
8e3797be
...
...
@@ -49,6 +49,27 @@ if torch.version.hip is not None:
    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

    def mscclpp_generate_unique_id() -> bytes:
        raise NotImplementedError()

    def mscclpp_init_context(
        unique_id: bytes,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        raise NotImplementedError()

    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        raise NotImplementedError()

else:

    def init_custom_ar(
...
...
@@ -85,3 +106,36 @@ else:
    def meta_size() -> int:
        return torch.ops.sgl_kernel.meta_size.default()

    def mscclpp_generate_unique_id() -> torch.Tensor:
        return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()

    def mscclpp_init_context(
        unique_id: torch.Tensor,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        return torch.ops.sgl_kernel.mscclpp_init_context.default(
            unique_id,
            rank,
            world_size,
            scratch,
            put_buffer,
            nranks_per_node,
            rank_to_node,
            rank_to_ib,
            context_selection,
        )

    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        torch.ops.sgl_kernel.mscclpp_allreduce.default(context, inp, out, nthreads, nblocks)
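Taken together, the wrappers above are meant to be driven in a fixed per-rank order: generate the unique id on rank 0, broadcast it over a CPU (gloo) group, allocate the bf16 scratch and put buffers, create the context, then issue all-reduces. A condensed sketch of that sequence follows; it mirrors the unit test below, and the buffer sizes, nthreads/nblocks values, and 8-GPU-per-node layout are the test's choices rather than requirements. The sketch assumes torch.distributed is already initialized with an NCCL group plus a gloo `cpu_group`, and that `rank`, `world_size`, and `device` are set for this process.

# Condensed per-rank call order for the mscclpp wrappers above (mirrors the unit test below).
import torch
import torch.distributed as dist
import sgl_kernel.allreduce as custom_ops

MAX_BYTES = 2**20
scratch = torch.empty(MAX_BYTES * 8, dtype=torch.bfloat16, device=device)
put_buffer = torch.empty(MAX_BYTES, dtype=torch.bfloat16, device=device)

# Rank 0 creates the MSCCL++ bootstrap id; everyone else receives it over the CPU group.
uid = [custom_ops.mscclpp_generate_unique_id() if rank == 0 else None]
dist.broadcast_object_list(uid, src=0, device=torch.device("cpu"), group=cpu_group)

ctx = custom_ops.mscclpp_init_context(
    uid[0], rank, world_size, scratch, put_buffer,
    nranks_per_node=torch.cuda.device_count(),
    rank_to_node=[r // 8 for r in range(world_size)],
    rank_to_ib=[r % 8 for r in range(world_size)],
    context_selection=1,  # 1 = one-node one-shot LL, 2 = two-node LL
)

inp = torch.randint(1, 16, (4096,), dtype=torch.float16, device=device)
out = torch.empty_like(inp)
custom_ops.mscclpp_allreduce(ctx, inp, out, nthreads=512, nblocks=21)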
sgl-kernel/tests/test_mscclpp.py
0 → 100644
View file @
8e3797be
import multiprocessing as mp
import os
import socket
import unittest
from enum import IntEnum
from typing import Any

import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist


class MscclContextSelection(IntEnum):
    MSCCL1SHOT1NODELL = 1
    MSCCL1SHOT2NODELL = 2


def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")

    if rank == 0:
        unique_id = [custom_ops.mscclpp_generate_unique_id()]
    else:
        unique_id = [None]
    dist.broadcast_object_list(unique_id, src=0, device=torch.device("cpu"), group=cpu_group)
    unique_id = unique_id[0]

    rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
    for r in range(world_size):
        rank_to_node[r] = r // 8
        rank_to_ib[r] = r % 8

    MAX_BYTES = 2**20
    scratch = torch.empty(
        MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    put_buffer = torch.empty(
        MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    print(f"[{rank}] start mscclpp_context init")
    nranks_per_node = torch.cuda.device_count()
    selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
    mscclpp_context = custom_ops.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        selection,
    )

    try:
        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                if sz * dtype.itemsize > MAX_BYTES:
                    continue
                if rank == 0:
                    print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
                for _ in range(test_loop):
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.mscclpp_allreduce(
                        mscclpp_context, inp1, out1, nthreads=512, nblocks=21
                    )
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        dist.barrier(group=group)
        dist.destroy_process_group(group=group)


def get_open_port() -> int:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)
    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)
    for i in range(world_size):
        procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"


class TestMSCCLAllReduce(unittest.TestCase):
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
    ]
    world_sizes = [8]

    def test_correctness(self):
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}: only {available_gpus} GPUs available, "
                    "and ray is not supported here"
                )
                continue
            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")


if __name__ == "__main__":
    unittest.main()
test/srt/test_mscclpp.py
0 → 100644
View file @
8e3797be
"""For Now, MSCCL is only supported on TP16 and TP8 case
if [[ $RANK -eq 0 ]]; then
ray start --block --head --port=6379 &
python3 test_mscclpp.py;
else
ray start --block --address=${MASTER_ADDR}:6379;
fi
"""
import itertools
import os
import random
import socket
import unittest
from contextlib import contextmanager, nullcontext
from typing import Any, List, Optional, Union

import ray
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
    CustomAllreduce,
)
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
)
from sglang.srt.distributed.utils import StatelessProcessGroup
from sglang.test.test_utils import CustomTestCase


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    master_addr: str,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray makes it easier to debug failures than multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=True)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(
            test_target.remote(cls, world_size, master_addr, rank, distributed_init_port)
        )
    ray.get(refs)

    ray.shutdown()


class TestMSCCLAllReduce(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        # 1KB to 1MB
        cls.test_sizes = [512, 4096, 32768, 262144, 524288]
        cls.world_sizes = [8]
        TEST_TP16 = int(os.getenv("SGL_MSCCLPP_TEST_TP16", "0"))
        if TEST_TP16:
            cls.world_sizes = [16]
        cls.test_loop = 10

    def test_graph_allreduce(self):
        TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
        for world_size in self.world_sizes:
            if world_size not in [8, 16]:
                continue
            multi_process_parallel(
                world_size, TEST_MASTER_ADDR, self, self.graph_allreduce
            )

    def test_eager_allreduce(self):
        TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
        for world_size in self.world_sizes:
            if world_size not in [8, 16]:
                continue
            multi_process_parallel(
                world_size, TEST_MASTER_ADDR, self, self.eager_allreduce
            )

    @ray.remote(num_gpus=1, max_calls=1)
    def graph_allreduce(self, world_size, master_addr, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
        set_mscclpp_all_reduce(True)
        set_custom_all_reduce(False)
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank % torch.cuda.device_count(),
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        # A small all_reduce for warmup.
        # this is needed because device communicators might be created lazily
        # (e.g. NCCL). This will ensure that the communicator is initialized
        # before any communication happens, so that this group can be used for
        # graph capture immediately.
        data = torch.zeros(1)
        data = data.to(device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    with graph_capture() as graph_capture_context:
                        # use integers so result matches NCCL exactly
                        inp1 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        inp2 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        torch.cuda.synchronize()
                        graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                    graph.replay()
                    torch.testing.assert_close(out1, inp1)
                    torch.testing.assert_close(out2, inp2)

    @ray.remote(num_gpus=1, max_calls=1)
    def eager_allreduce(self, world_size, master_addr, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
        set_mscclpp_all_reduce(True)
        set_custom_all_reduce(False)
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)


if __name__ == "__main__":
    unittest.main()