GitLab project: change / sglang — Commits

Commit 9cf40772 (Unverified)
Authored Mar 02, 2025 by Hubert Lu; committed via GitHub on Mar 02, 2025
Parent: d3fe9bae

    Enable custom AR for AMD GPUs and maintain it in sgl-kernel (#3406)

Showing 9 changed files with 1278 additions and 191 deletions (+1278 −191).
Changed files:

  python/sglang/srt/_custom_ops.py                                           +81  −33
  python/sglang/srt/distributed/device_communicators/custom_all_reduce.py   +222  −80
  sgl-kernel/setup_rocm.py                                                     +1   −0
  sgl-kernel/src/sgl-kernel/__init__.py                                      +127  −63
  sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip                       +180   −0
  sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh                   +554   −0
  sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h                         +17   −1
  sgl-kernel/src/sgl-kernel/ops/__init__.py                                   +65  −14
  sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc                           +31   −0
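The heart of the change is a three-way dispatch on the Python side: vLLM's custom all-reduce is used only when USE_VLLM_CUSTOM_ALLREDUCE is set and the build is not ROCm, ROCm now routes to the HIP custom all-reduce maintained in sgl-kernel, and everything else keeps the TRT-style sgl-kernel path. Below is a minimal, dependency-free sketch of that selection logic; is_hip/is_hpu are stubbed here purely for illustration and are not the real sglang.srt.utils helpers.

import os


def is_hip() -> bool:   # stand-in for sglang.srt.utils.is_hip (illustration only)
    return False


def is_hpu() -> bool:   # stand-in for sglang.srt.utils.is_hpu (illustration only)
    return False


# Same gating as _custom_ops.py below: vLLM kernels only off-ROCm, the HIP
# custom all-reduce from this commit on ROCm, TRT-style sgl-kernel otherwise.
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)

if not is_hpu():
    if use_vllm_custom_allreduce and not is_hip():
        backend = "vllm._C custom allreduce"
    elif is_hip():
        backend = "sgl-kernel HIP custom allreduce"
    else:
        backend = "sgl-kernel TRT-style custom allreduce"
    print(backend)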
python/sglang/srt/_custom_ops.py  (+81 −33)

@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 import torch
 import torch.library

-from sglang.srt.utils import is_hpu
+from sglang.srt.utils import is_hip, is_hpu

 logger = logging.getLogger(__name__)

 use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)

 if not is_hpu():
-    if use_vllm_custom_allreduce:
+    # Remove vllm dependency for custom allreduce on ROCm
+    if use_vllm_custom_allreduce and not is_hip():
         try:
             import vllm._C
         except ImportError as e:

@@ -56,7 +56,7 @@ def hint_on_error(fn):
     return wrapper


-if use_vllm_custom_allreduce:
+if use_vllm_custom_allreduce and not is_hip():
     # custom ar
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],

@@ -95,39 +95,87 @@ if use_vllm_custom_allreduce:
         torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

 else:
-    # custom ar
-    def init_custom_ar(
-        rank_id: int,
-        world_size: int,
-        rank_data_base: torch.Tensor,
-        buffers: List[int],
-        tmp_result_buffers: List[int],
-        barrier_in: List[int],
-        barrier_out: List[int],
-    ) -> int:
-        return sgl_kernel.ops.init_custom_reduce(
-            rank_id, world_size, rank_data_base, buffers,
-            tmp_result_buffers, barrier_in, barrier_out,
-        )
-
-    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        sgl_kernel.ops.custom_reduce(fa, inp, out)
-
-    def dispose(fa: int) -> None:
-        sgl_kernel.ops.custom_dispose(fa)
-
-    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
-
-    def register_graph_buffers(
-        fa: int, handles: List[List[int]], offsets: List[List[int]]
-    ) -> None:
-        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+    if is_hip():
+
+        def init_custom_ar(
+            meta: torch.Tensor,
+            rank_data: torch.Tensor,
+            handles: List[str],
+            offsets: List[int],
+            rank: int,
+            full_nvlink: bool,
+        ) -> int:
+            return sgl_kernel.ops.init_custom_ar(
+                meta, rank_data, handles, offsets, rank, full_nvlink
+            )
+
+        def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+            sgl_kernel.ops.all_reduce_reg(fa, inp, out)
+
+        def all_reduce_unreg(
+            fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
+        ) -> None:
+            sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out)
+
+        def dispose(fa: int) -> None:
+            sgl_kernel.ops.dispose(fa)
+
+        def meta_size() -> int:
+            return sgl_kernel.ops.meta_size()
+
+        def register_buffer(
+            fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
+        ) -> None:
+            return sgl_kernel.ops.register_buffer(fa, t, handles, offsets)
+
+        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
+            return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+
+        def register_graph_buffers(
+            fa: int, handles: List[str], offsets: List[List[int]]
+        ) -> None:
+            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+
+        def allocate_meta_buffer(size: int) -> torch.Tensor:
+            return sgl_kernel.ops.allocate_meta_buffer(size)
+
+        def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
+            return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
+
+    else:
+        # custom ar
+        def init_custom_ar(
+            rank_id: int,
+            world_size: int,
+            rank_data_base: torch.Tensor,
+            buffers: List[int],
+            tmp_result_buffers: List[int],
+            barrier_in: List[int],
+            barrier_out: List[int],
+        ) -> int:
+            return sgl_kernel.ops.init_custom_reduce(
+                rank_id,
+                world_size,
+                rank_data_base,
+                buffers,
+                tmp_result_buffers,
+                barrier_in,
+                barrier_out,
+            )
+
+        def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+            sgl_kernel.ops.custom_reduce(fa, inp, out)
+
+        def dispose(fa: int) -> None:
+            sgl_kernel.ops.custom_dispose(fa)
+
+        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+            return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+
+        def register_graph_buffers(
+            fa: int, handles: List[List[int]], offsets: List[List[int]]
+        ) -> None:
+            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)

 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
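On ROCm, these wrappers are driven by CustomAllreduce (next file) in roughly the following order. The snippet below sketches only the single-rank portion that needs no peers — allocating the uncached signal buffer and exporting its IPC handle; gathering handles across ranks, init_custom_ar, and register_buffer still require an initialized process group. It assumes a ROCm build of sgl-kernel and at least one AMD GPU.

import torch

from sglang.srt import _custom_ops as ops  # resolves to the is_hip() branch above

max_size = 2 * 8192 * 1024  # ROCm buffer cap used by CustomAllreduce

# Uncached metadata buffer (signals + temporary space) and its IPC handle.
meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
handle = ops.get_meta_buffer_ipc_handle(meta)
shard_data = (bytes(handle), 0)  # (ipc handle of base ptr, offset), to be gathered across ranks
print(len(shard_data[0]))        # size of the exported hipIpcMemHandle_t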
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (+222 −80)

@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     gpu_p2p_access_check,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
-from sglang.srt.utils import cuda_device_count_stateless, is_cuda
+from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip

 logger = logging.getLogger(__name__)

@@ -28,14 +28,27 @@ if is_cuda():
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)

+if is_hip():
+    try:
+        from amdsmi import (
+            AmdSmiException,
+            amdsmi_get_gpu_board_info,
+            amdsmi_get_processor_handles,
+            amdsmi_init,
+            amdsmi_shut_down,
+            amdsmi_topo_get_link_type,
+        )
+    except ImportError as e:
+        logger.warning("Failed to import amdsmi with %r", e)
+
 try:
-    if ops.use_vllm_custom_allreduce:
+    if ops.use_vllm_custom_allreduce and not is_hip():
         ops.meta_size()
     else:
         import sgl_kernel

     custom_ar = True
 except Exception:
-    # For AMD GPUs and CPUs
+    # For CPUs
     custom_ar = False

 logger = logging.getLogger(__name__)

@@ -47,37 +60,62 @@ _R = TypeVar("_R")

 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        pynvml.nvmlInit()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            pynvml.nvmlShutdown()
+        if torch.version.hip:
+            try:
+                amdsmi_init()
+                return fn(*args, **kwargs)
+            finally:
+                amdsmi_shut_down()
+        else:
+            pynvml.nvmlInit()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                pynvml.nvmlShutdown()

     return wrapper


 @with_nvml_context
-def is_full_nvlink(physical_device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
-                    )
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError:
-                    logger.exception(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped."
-                    )
-                    return False
-    return True
+def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
+    if is_hip():
+        """
+        query if the set of gpus are fully connected by xgmi (1 hop)
+        """
+        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
+                        # type is 2 for XGMI
+                        if link_type["hops"] != 1 or link_type["type"] != 2:
+                            return False
+                    except AmdSmiException as error:
+                        logger.error(
+                            "AMD 1 hop XGMI detection failed.", exc_info=error
+                        )
+                        return False
+        return True
+    else:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
+                        )
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError:
+                        logger.exception(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped."
+                        )
+                        return False
+        return True


 def _can_p2p(rank: int, world_size: int) -> bool:
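amdsmi_topo_get_link_type returns a mapping with "hops" and "type" fields, and the ROCm branch above accepts a GPU pair only when it is one hop apart over XGMI (type 2). A tiny self-contained sketch of that predicate, with the amdsmi call replaced by a hypothetical stub so it runs anywhere:

from itertools import combinations


def fake_topo_get_link_type(a: int, b: int) -> dict:
    # Hypothetical topology standing in for amdsmi_topo_get_link_type():
    # every pair is one XGMI hop (type 2), as on a fully connected node.
    return {"hops": 1, "type": 2}


def fully_connected_by_xgmi(device_ids) -> bool:
    for a, b in combinations(device_ids, 2):
        link = fake_topo_get_link_type(a, b)
        if link["hops"] != 1 or link["type"] != 2:
            return False
    return True


print(fully_connected_by_xgmi([0, 1, 2, 3]))  # True for the stubbed topology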
@@ -102,15 +140,18 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:

     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+    _MAX_CAR_SIZE = 8192 * 1024
+    if is_hip():
+        # crossover is at 16MB buffer size for ROCm
+        _MAX_CAR_SIZE = 2 * 8192 * 1024

     # max_size: max supported allreduce size
     def __init__(
         self,
         group: ProcessGroup,
         device: Union[int, str, torch.device],
-        max_size=8192 * 1024,
+        max_size=_MAX_CAR_SIZE,
     ) -> None:
         """
         Args:
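For scale: 8192 * 1024 bytes is 8 MiB, so the ROCm crossover doubles the cap to 16 MiB — up to 8M fp16/bf16 elements stay eligible for the custom kernel. A quick check:

default_max = 8192 * 1024        # 8 MiB: default cap
rocm_max = 2 * 8192 * 1024       # 16 MiB: ROCm crossover from the diff
print(default_max // (1 << 20), rocm_max // (1 << 20))  # 8 16
print(rocm_max // 2)             # 8388608 fp16/bf16 elements fit under the ROCm cap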
@@ -185,12 +226,9 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda():
-            assert is_cuda()
-            full_nvlink = is_full_nvlink(physical_device_ids)
+        if is_cuda() or is_hip():
+            full_nvlink = is_full_nvlink(physical_device_ids, world_size)
         else:
             full_nvlink = False
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"

@@ -201,7 +239,8 @@ class CustomAllreduce:
         # test P2P capability, this checks software/cudaruntime support
         # this is expensive to compute at the first time
         # then we cache the result
-        if not _can_p2p(rank, world_size):
+        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
+        if not is_hip() and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "

@@ -214,7 +253,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        if ops.use_vllm_custom_allreduce:
+        if ops.use_vllm_custom_allreduce and not is_hip():
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
@@ -237,35 +276,56 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            # From TensorRT-LLM getMaxRequiredWorkspaceSize
-            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
-
-            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-            self.barrier_max_size = 8 * (36 + 2) * 8
-
-            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-            self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-            self.rank_data_base = torch.empty(
-                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-            )
-            self.barrier_in_ptrs = self.create_shared_buffer(
-                self.barrier_max_size, group=group
-            )
-            self.barrier_out_ptrs = self.create_shared_buffer(
-                self.barrier_max_size, group=group
-            )
-
-            self._ptr = ops.init_custom_ar(
-                rank,
-                world_size,
-                self.rank_data_base,
-                self.buffer_ptrs,
-                self.tmp_result_buffer_ptrs,
-                self.barrier_in_ptrs,
-                self.barrier_out_ptrs,
-            )
+            if is_hip():
+                # meta data buffers need to be "uncached" for signal on MI200
+                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
+                self.buffer = torch.empty(
+                    max_size, dtype=torch.uint8, device=self.device
+                )
+                handle = ops.get_meta_buffer_ipc_handle(self.meta)
+                shard_data = (
+                    bytes(handle),  # ipc handle to base ptr
+                    0,  # offset of base ptr
+                )
+                handles, offsets = self._gather_ipc_meta(shard_data)
+                self.rank_data = torch.empty(
+                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+                )
+                self._ptr = ops.init_custom_ar(
+                    self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
+                )
+                self.register_buffer(self.buffer)
+                self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
+            else:
+                # From TensorRT-LLM getMaxRequiredWorkspaceSize
+                self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+                # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+                self.barrier_max_size = 8 * (36 + 2) * 8
+
+                self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+                self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+                    max_size, group=group
+                )
+                self.rank_data_base = torch.empty(
+                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+                )
+                self.barrier_in_ptrs = self.create_shared_buffer(
+                    self.barrier_max_size, group=group
+                )
+                self.barrier_out_ptrs = self.create_shared_buffer(
+                    self.barrier_max_size, group=group
+                )
+
+                self._ptr = ops.init_custom_ar(
+                    rank,
+                    world_size,
+                    self.rank_data_base,
+                    self.buffer_ptrs,
+                    self.tmp_result_buffer_ptrs,
+                    self.barrier_in_ptrs,
+                    self.barrier_out_ptrs,
+                )
         self.disabled = False

     @staticmethod
@@ -316,23 +376,69 @@ class CustomAllreduce:
         if not self.disabled:
             self.register_graph_buffers()

-    def register_graph_buffers(self):
-        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-        logger.info("Registering %d cuda graph addresses", len(offset))
-        # We cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(
-                all_data[i], src=rank, group=self.group, device="cpu"
-            )
-        # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
-        ops.register_graph_buffers(self._ptr, handles, offsets)
+    def _get_ipc_meta(self, inp: torch.Tensor):
+        # _share_cuda_() doesn't accept meta buffer not allocated from
+        # PyTorch cache allocator, use direct HIP call to get IPC handle
+        handle = ops.get_meta_buffer_ipc_handle(inp)
+        shard_data = (
+            bytes(handle),  # ipc handle to base ptr
+            0,  # offset of base ptr
+        )
+        return self._gather_ipc_meta(shard_data)
+
+    def _gather_ipc_meta(self, shard_data):
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it will create a list of the same reference
+        all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
+        all_data[self.rank][0] = shard_data
+
+        ranks = dist.get_process_group_ranks(group=self.group)
+        ranks.sort()
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(
+                all_data[i], src=rank, group=self.group, device="cpu"
+            )
+
+        # we cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        handles = []
+        offsets = []
+        for i in range(len(all_data)):
+            handles.append(all_data[i][0][0])  # type: ignore
+            offsets.append(all_data[i][0][1])  # type: ignore
+        return handles, offsets
+
+    def register_buffer(self, inp: torch.Tensor):
+        handles, offsets = self._get_ipc_meta(inp)
+        ops.register_buffer(self._ptr, inp, handles, offsets)
+
+    def register_graph_buffers(self):
+        if is_hip():
+            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
+            logger.info("Registering %d cuda graph addresses", len(offset))
+            ops.register_graph_buffers(self._ptr, handles, offsets)
+        else:
+            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+            logger.info("Registering %d cuda graph addresses", len(offset))
+            # We cannot directly use `dist.all_gather_object` here
+            # because it is incompatible with `gloo` backend under inference mode.
+            # see https://github.com/pytorch/pytorch/issues/126032 for details.
+            all_data = [
+                [None, None] for _ in range(dist.get_world_size(group=self.group))
+            ]
+            all_data[self.rank] = [handle, offset]
+            ranks = sorted(dist.get_process_group_ranks(group=self.group))
+            for i, rank in enumerate(ranks):
+                dist.broadcast_object_list(
+                    all_data[i], src=rank, group=self.group, device="cpu"
+                )
+            # Unpack list of tuples to tuple of lists.
+            handles = [d[0] for d in all_data]  # type: ignore
+            offsets = [d[1] for d in all_data]  # type: ignore
+            ops.register_graph_buffers(self._ptr, handles, offsets)

     def should_custom_ar(self, inp: torch.Tensor):
         if self.disabled:
@@ -345,11 +451,22 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce:
+        if ops.use_vllm_custom_allreduce and not is_hip():
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False
+        if is_hip():
+            if self.full_nvlink:
+                if self.world_size == 8:
+                    if self.MSCCL:
+                        return False
+                    else:
+                        return inp_size < self.max_size
+                else:
+                    return inp_size < self.max_size
+            return False
         if self.world_size == 2:
             return (
                 inp_size < self.max_size
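Condensed, the ROCm branch of should_custom_ar enables the custom kernel only under full XGMI connectivity and defers to RCCL for 8-GPU jobs when MSCCL is enabled. A pure-Python restatement of that decision (a sketch of the logic above, not the class itself):

def rocm_should_custom_ar(
    inp_bytes: int, max_size: int, world_size: int, full_xgmi: bool, msccl: bool
) -> bool:
    if not full_xgmi:
        return False
    if world_size == 8 and msccl:
        return False
    return inp_bytes < max_size


# A 4 MiB tensor on an 8-GPU, fully connected node with MSCCL disabled:
print(rocm_should_custom_ar(4 * 1024 * 1024, 2 * 8192 * 1024, 8, True, False))  # True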
@@ -364,6 +481,21 @@ class CustomAllreduce:
             return False

+    # all reduce, assuming inp tensor is IPC registered with register_buffer,
+    # or, in the context of cuda graphs, register_graph_buffers
+    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.all_reduce_reg(self._ptr, inp, out)
+        return out
+
+    # all reduce, assuming inp tensor is NOT IPC registered
+    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
+        return out
+
     def all_reduce(
         self,
         inp: torch.Tensor,
@@ -397,13 +529,23 @@ class CustomAllreduce:
                 return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                if is_hip():
+                    return self.all_reduce_reg(input)
+                else:
+                    return self.all_reduce(input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            return self.all_reduce(input, registered=False)
+            if is_hip():
+                # note: outside of cuda graph context,
+                # custom allreduce incurs a cost of cudaMemcpy, which should
+                # be small(<=1% of overall latency) compared to the performance
+                # gains of using custom kernels
+                return self.all_reduce_unreg(input)
+            else:
+                return self.all_reduce(input, registered=False)

     def close(self):
         if not self.disabled and self._ptr:
@@ -411,7 +553,7 @@ class CustomAllreduce:
             if ops.use_vllm_custom_allreduce:
                 self.free_shared_buffer(self.meta_ptrs)
                 self.free_shared_buffer(self.buffer_ptrs)
-            else:
+            elif is_cuda():
                 self.free_shared_buffer(self.buffer_ptrs)
                 self.free_shared_buffer(self.tmp_result_buffer_ptrs)
                 self.free_shared_buffer(self.barrier_in_ptrs)
sgl-kernel/setup_rocm.py  (+1 −0)

@@ -44,6 +44,7 @@ include_dirs = [
 sources = [
     "src/sgl-kernel/torch_extension_rocm.cc",
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
+    "src/sgl-kernel/csrc/custom_all_reduce.hip",
 ]

 cxx_flags = ["-O3"]
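Once the .hip source is part of the ROCm extension build, the new entry points should surface through the sgl_kernel package. A quick smoke check (assumes a ROCm build of sgl-kernel is installed):

import torch
import sgl_kernel

if torch.version.hip is not None:
    expected = (
        "init_custom_ar",
        "all_reduce_reg",
        "all_reduce_unreg",
        "allocate_meta_buffer",
        "get_meta_buffer_ipc_handle",
    )
    # All of these are exported by sgl_kernel/__init__.py on ROCm builds.
    print(all(hasattr(sgl_kernel, name) for name in expected))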
sgl-kernel/src/sgl-kernel/__init__.py  (+127 −63, full file after the change)

import ctypes
import os

import torch

if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
    ctypes.CDLL(
        "/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
        mode=ctypes.RTLD_GLOBAL,
    )

if torch.version.hip is not None:
    from sgl_kernel.ops import (
        all_reduce_reg,
        all_reduce_unreg,
        allocate_meta_buffer,
        apply_rope_with_cos_sin_cache_inplace,
        bmm_fp8,
        dispose,
        fp8_scaled_mm,
        fused_add_rmsnorm,
        gelu_and_mul,
        gelu_tanh_and_mul,
        gemma_fused_add_rmsnorm,
        gemma_rmsnorm,
        get_graph_buffer_ipc_meta,
        get_meta_buffer_ipc_handle,
        init_custom_ar,
        int8_scaled_mm,
        lightning_attention_decode,
        meta_size,
        min_p_sampling_from_probs,
        moe_align_block_size,
        register_buffer,
        register_graph_buffers,
        rmsnorm,
        sampling_scaling_penalties,
        silu_and_mul,
        top_k_renorm_prob,
        top_k_top_p_sampling_from_probs,
        top_p_renorm_prob,
    )

    from .version import __version__

    __all__ = [
        "all_reduce_reg",
        "all_reduce_unreg",
        "allocate_meta_buffer",
        "apply_rope_with_cos_sin_cache_inplace",
        "bmm_fp8",
        "dispose",
        "fp8_scaled_mm",
        "fused_add_rmsnorm",
        "gelu_and_mul",
        "gelu_tanh_and_mul",
        "gemma_fused_add_rmsnorm",
        "gemma_rmsnorm",
        "get_graph_buffer_ipc_meta",
        "get_meta_buffer_ipc_handle",
        "init_custom_ar",
        "int8_scaled_mm",
        "lightning_attention_decode",
        "meta_size",
        "min_p_sampling_from_probs",
        "moe_align_block_size",
        "register_buffer",
        "register_graph_buffers",
        "rmsnorm",
        "sampling_scaling_penalties",
        "silu_and_mul",
        "top_k_renorm_prob",
        "top_k_top_p_sampling_from_probs",
        "top_p_renorm_prob",
    ]
else:
    from sgl_kernel.ops import (
        apply_rope_with_cos_sin_cache_inplace,
        bmm_fp8,
        build_tree_kernel,
        build_tree_kernel_efficient,
        cublas_grouped_gemm,
        custom_dispose,
        custom_reduce,
        fp8_blockwise_scaled_mm,
        fp8_scaled_mm,
        fused_add_rmsnorm,
        gelu_and_mul,
        gelu_tanh_and_mul,
        gemma_fused_add_rmsnorm,
        gemma_rmsnorm,
        get_graph_buffer_ipc_meta,
        init_custom_reduce,
        int8_scaled_mm,
        lightning_attention_decode,
        min_p_sampling_from_probs,
        moe_align_block_size,
        register_graph_buffers,
        rmsnorm,
        sampling_scaling_penalties,
        sgl_per_token_group_quant_fp8,
        silu_and_mul,
        top_k_renorm_prob,
        top_k_top_p_sampling_from_probs,
        top_p_renorm_prob,
        tree_speculative_sampling_target_only,
    )

    __all__ = [
        "apply_rope_with_cos_sin_cache_inplace",
        "bmm_fp8",
        "cublas_grouped_gemm",
        "custom_dispose",
        "custom_reduce",
        "fp8_blockwise_scaled_mm",
        "fp8_scaled_mm",
        "fused_add_rmsnorm",
        "gelu_and_mul",
        "gelu_tanh_and_mul",
        "gemma_fused_add_rmsnorm",
        "gemma_rmsnorm",
        "get_graph_buffer_ipc_meta",
        "init_custom_reduce",
        "int8_scaled_mm",
        "lightning_attention_decode",
        "min_p_sampling_from_probs",
        "moe_align_block_size",
        "register_graph_buffers",
        "rmsnorm",
        "sampling_scaling_penalties",
        "silu_and_mul",
        "top_k_renorm_prob",
        "top_k_top_p_sampling_from_probs",
        "top_p_renorm_prob",
        "tree_speculative_sampling_target_only",
        "build_tree_kernel_efficient",
        "build_tree_kernel",
        "sgl_per_token_group_quant_fp8",
    ]
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip  (new file, +180)
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/all.h>
#include "custom_all_reduce_hip.cuh"
// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (world_size != handles.size())
throw std::invalid_argument(
"handles length should equal to offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
hipIpcMemHandle_t ipc_handles[8];
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows transpose of contiguous slice (i.e. slicing the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
}
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
hipStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
_all_reduce(_fa, inp, out, stream);
}
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto input_size = inp.numel() * inp.element_size();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
"registered buffer is too small to contain the input");
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
input_size, hipMemcpyDeviceToDevice, stream));
_all_reduce(_fa, reg_buffer, out, stream);
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa;
}
int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
}
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
inp.data_ptr()));
return data_handle;
}
torch::Tensor allocate_meta_buffer(int64_t size) {
auto device_index = c10::hip::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(hipStreamSynchronize(stream));
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
auto options = torch::TensorOptions()
.dtype(torch::kI8)
.device(torch::kCUDA, device_index);
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}
std::vector<uint8_t> get_device_bdf(int dev) {
char busIdStr[] = "0000:00:00.0";
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
bdf.resize(bdf.size() - 1); // remove trailing NULL
return bdf;
}
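_is_weak_contiguous mirrors the check the Python side applies before launching the kernel: the tensor's data must occupy one dense span of its storage, even if it is not contiguous in the strict PyTorch sense. A small sketch of the same predicate in Python, exercised on the cases listed in the comment above:

import torch


def is_weak_contiguous(t: torch.Tensor) -> bool:
    # Same condition as _is_weak_contiguous above, expressed on the Python side.
    return t.is_contiguous() or (
        t.storage().nbytes() - t.storage_offset() * t.element_size()
        == t.numel() * t.element_size()
    )


A = torch.zeros(3, 3, 3)
print(is_weak_contiguous(A))                   # True
print(is_weak_contiguous(A[1:]))               # True: dense slice of the storage
print(is_weak_contiguous(A.permute(2, 0, 1)))  # True: same dense span, transposed
print(is_weak_contiguous(A[:, 1:, 1:]))        # False: holes in the span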
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh  (new file, +554)
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#else
#include <hip/hip_bf16.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#define CUDACHECK(cmd) \
do { \
hipError_t e = cmd; \
if (e != hipSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace vllm {

constexpr int kMaxBlocks = 64;
// note: we don't want to use atomics for signals because peer atomics are no
// supported on PCIe links
struct Signal {
  alignas(128) uint32_t start[kMaxBlocks][8];
  alignas(128) uint32_t end[kMaxBlocks][8];
  alignas(128) uint32_t _flag[kMaxBlocks];  // incremental flags for each rank
};

#ifdef USE_ROCM
struct __align__(16) RankData { const void* ptrs[8]; };
#else
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
#endif

struct __align__(16) RankSignals {
#ifndef USE_ROCM
  volatile
#endif
      Signal* signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};

#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float& assign_add(float& a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}

// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg,
#ifndef USE_ROCM
                        volatile
#endif
                        Signal* self_sg,
                        int rank) {
#ifdef USE_ROCM
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
                            __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
    // wait until we got true from all ranks
    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
                                  __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) < flag)
      ;
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
  if (threadIdx.x < ngpus) {
    // reset flag for next time
    self_sg->end[blockIdx.x][threadIdx.x] = 0;
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
    while (!self_sg->start[blockIdx.x][threadIdx.x])
      ;
  }
  __syncthreads();
#endif
}

// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg,
#ifndef USE_ROCM
                      volatile
#endif
                      Signal* self_sg,
                      int rank) {
#ifdef USE_ROCM
  __syncthreads();
  // eliminate the case that prior writes are not visible after signals become
  // visible. Note that I did not managed to make this happen through a lot of
  // testing. Might be the case that hardware provides stronger guarantee than
  // the memory model.
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
                            final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
                            __MEMORY_SCOPE_SYSTEM);
    // wait until we got true from all ranks
    while (__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                                  final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
                                  __MEMORY_SCOPE_DEVICE) < flag)
      ;
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
  __syncthreads();
  // eliminate the case that prior writes are not visible after signals become
  // visible. Note that I did not managed to make this happen through a lot of
  // testing. Might be the case that hardware provides stronger guarantee than
  // the memory model.
  if constexpr (!final_sync) __threadfence_system();
  if (threadIdx.x < ngpus) {
    // reset flag for next time
    self_sg->start[blockIdx.x][threadIdx.x] = 0;
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
    while (!self_sg->end[blockIdx.x][threadIdx.x])
      ;
  }
  if constexpr (!final_sync) __syncthreads();
#endif
}

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
#ifndef USE_ROCM
                               volatile
#endif
                               Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, self_sg, rank);
}

template <typename P>
#ifdef USE_ROCM
DINLINE P* get_tmp_buf(Signal* sg) {
#else
DINLINE P* get_tmp_buf(volatile Signal* sg) {
#endif
  return (P*)(((Signal*)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
#ifndef USE_ROCM
                               volatile
#endif
                               Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from all
  // ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}

using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   *
   * There's a total of sizeof(Signal) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
                  const hipIpcMemHandle_t* handles,
                  const std::vector<int64_t>& offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Signal* rank_sg;
      if (i != rank_) {
        char* handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_sg = (Signal*)handle;
      } else {
        rank_sg = self_sg_;
      }
      sg_.signals[i] = rank_sg;
    }
  }

  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(hipIpcOpenMemHandle((void**)&ipc_ptr,
                                    *((const hipIpcMemHandle_t*)ipc_handle),
                                    hipIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(hipIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (hipPointerGetAttribute(&base_ptr,
#ifdef USE_ROCM
                                 HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
                                 CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
                                 (hipDeviceptr_t)ptr) != hipSuccess)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz],
                                   base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, void* self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char* handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(),
                        sizeof(RankData) * num_buffers, hipMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * This is the result after careful grid search. Using 36 blocks give the best
   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
   * Not quite sure the underlying reason, but my guess is that too many SMs
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(hipStream_t stream, T* input, T* output, int size,
#ifndef USE_ROCM
                 int threads = 512, int block_limit = 36) {
#else
                 int threads = 512, int block_limit = 16) {
#endif
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));

    RankData* ptrs;
    hipStreamCaptureStatus status;
    CUDACHECK(hipStreamIsCapturing(stream, &status));
    if (status == hipStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  hipLaunchKernelGGL((name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, \
                     size);
#define REDUCE_CASE(ngpus)                                                                        \
  case ngpus: {                                                                                   \
    if (world_size_ == 2) {                                                                       \
      KL(ngpus, cross_device_reduce_1stage);                                                      \
    } else if (full_nvlink_) {                                                                    \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);                                                    \
      } else {                                                                                    \
        KL(ngpus, cross_device_reduce_2stage);                                                    \
      }                                                                                           \
    }                                                                                             \
    break;                                                                                        \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(hipIpcCloseMemHandle(ptr));
    }
  }
};
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 a template instantiation:
 * template void vllm::CustomAllreduce::allreduce<half>(hipStream_t, half *,
 half *, int, int, int);
 */
}  // namespace vllm
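In cross_device_reduce_2stage, rank r reduces the slice [r*part, r*part+part) into its temporary buffer (the last rank also takes the size % ngpus remainder), and every rank then gathers all slices; largest_part bounds the gather loop. A quick index-bookkeeping sketch confirming the partitioning covers each packed element exactly once:

def partitions(size: int, ngpus: int):
    part = size // ngpus
    for rank in range(ngpus):
        start = rank * part
        end = size if rank == ngpus - 1 else start + part
        yield start, end


size, ngpus = 1000, 6  # packed elements, GPUs (arbitrary example values)
covered = sorted(i for s, e in partitions(size, ngpus) for i in range(s, e))
assert covered == list(range(size))  # every packed element is reduced exactly once
print(max(e - s for s, e in partitions(size, ngpus)))  # 170 == part + size % ngpus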
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h  (+17 −1)

@@ -34,8 +34,23 @@ limitations under the License.
     return PyModule_Create(&module); \
   }

-// trt_reduce
 using fptr_t = int64_t;
+#ifdef USE_ROCM
+fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
+                      const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
+void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
+void dispose(fptr_t _fa);
+int64_t meta_size();
+void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles,
+                     const std::vector<int64_t>& offsets);
+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets);
+torch::Tensor allocate_meta_buffer(int64_t size);
+torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
+#else
+// trt_reduce
 fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data,
                       const std::vector<fptr_t>& buffers, const std::vector<fptr_t>& tmp_result_buffers,
                       const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);

@@ -44,6 +59,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
+#endif

 // moe_align_block_size
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
sgl-kernel/src/sgl-kernel/ops/__init__.py  (+65 −14)

@@ -64,28 +64,79 @@ def apply_rope_with_cos_sin_cache_inplace(
     )


-def init_custom_reduce(
-    rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
-):
-    return torch.ops.sgl_kernels.init_custom_ar(
-        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
-    )
-
-
-def custom_dispose(fa):
-    torch.ops.sgl_kernels.dispose(fa)
-
-
-def custom_reduce(fa, inp, out):
-    torch.ops.sgl_kernels.all_reduce(fa, inp, out)
-
-
-def get_graph_buffer_ipc_meta(fa):
-    return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
-
-
-def register_graph_buffers(fa, handles, offsets):
-    torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
+if torch.version.hip is not None:
+
+    def init_custom_ar(
+        meta: torch.Tensor,
+        rank_data: torch.Tensor,
+        handles: List[str],
+        offsets: List[int],
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return torch.ops.sgl_kernels.init_custom_ar(
+            meta, rank_data, handles, offsets, rank, full_nvlink
+        )
+
+    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
+
+    def all_reduce_unreg(
+        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
+
+    def dispose(fa: int) -> None:
+        torch.ops.sgl_kernels.dispose(fa)
+
+    def meta_size() -> int:
+        return torch.ops.sgl_kernels.meta_size()
+
+    def register_buffer(
+        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
+    ) -> None:
+        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
+        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[str], offsets: List[List[int]]
+    ) -> None:
+        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
+
+    def allocate_meta_buffer(size: int) -> torch.Tensor:
+        return torch.ops.sgl_kernels.allocate_meta_buffer(size)
+
+    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
+        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
+
+else:
+    # trt_reduce
+    def init_custom_reduce(
+        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
+    ):
+        return torch.ops.sgl_kernels.init_custom_ar(
+            rank_id,
+            num_devices,
+            rank_data,
+            buffers,
+            tmp_buffers,
+            barrier_in,
+            barrier_out,
+        )
+
+    def custom_dispose(fa):
+        torch.ops.sgl_kernels.dispose(fa)
+
+    def custom_reduce(fa, inp, out):
+        torch.ops.sgl_kernels.all_reduce(fa, inp, out)
+
+    def get_graph_buffer_ipc_meta(fa):
+        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(fa, handles, offsets):
+        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)


 def moe_align_block_size(
sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc  (+31 −0)

@@ -19,6 +19,37 @@ limitations under the License.
 #include "sgl_kernels_ops.h"

 TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
+  // Custom all-reduce kernels
+  m.def(
+      "init_custom_ar(Tensor meta, Tensor rank_data, "
+      "str[] handles, int[] offsets, int rank, "
+      "bool full_nvlink) -> int");
+  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
+
+  m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
+  m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
+
+  m.def(
+      "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
+      "()");
+  m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
+
+  m.def("dispose", &dispose);
+  m.def("meta_size", &meta_size);
+
+  m.def(
+      "register_buffer(int fa, Tensor t, str[] handles, "
+      "int[] offsets) -> ()");
+  m.impl("register_buffer", torch::kCUDA, &register_buffer);
+
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", &register_graph_buffers);
+
+  m.def("allocate_meta_buffer", &allocate_meta_buffer);
+  m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
+  m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
+  m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
+
   // moe_align_block_size
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
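The TORCH_LIBRARY block registers these functions under the sgl_kernels namespace, so once the extension is imported they are reachable as torch.ops.sgl_kernels.*. A small smoke test on a ROCm build (an assumption; the op set differs on other builds):

import torch
import sgl_kernel  # importing the package loads the ROCm extension and its op registrations

# meta_size() takes no tensor arguments, so it makes a convenient smoke test;
# the returned value is sizeof(vllm::Signal) from custom_all_reduce_hip.cuh.
print(torch.ops.sgl_kernels.meta_size())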