change / sglang · Commits · 9cf40772

Unverified commit 9cf40772, authored Mar 02, 2025 by Hubert Lu, committed by GitHub on Mar 02, 2025

Enable custom AR for AMD GPUs and maintain it in sgl-kernel (#3406)
parent d3fe9bae
Showing 9 changed files with 1278 additions and 191 deletions (+1278 -191)
python/sglang/srt/_custom_ops.py  (+81 -33)
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (+222 -80)
sgl-kernel/setup_rocm.py  (+1 -0)
sgl-kernel/src/sgl-kernel/__init__.py  (+127 -63)
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip  (+180 -0)
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh  (+554 -0)
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h  (+17 -1)
sgl-kernel/src/sgl-kernel/ops/__init__.py  (+65 -14)
sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc  (+31 -0)
python/sglang/srt/_custom_ops.py  (+81 -33)

@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch
 import torch.library

-from sglang.srt.utils import is_hpu
+from sglang.srt.utils import is_hip, is_hpu

 logger = logging.getLogger(__name__)

 use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)

 if not is_hpu():
-    if use_vllm_custom_allreduce:
+    # Remove vllm dependency for custom allreduce on ROCm
+    if use_vllm_custom_allreduce and not is_hip():
         try:
             import vllm._C
         except ImportError as e:

@@ -56,7 +56,7 @@ def hint_on_error(fn):

     return wrapper

-if use_vllm_custom_allreduce:
+if use_vllm_custom_allreduce and not is_hip():
     # custom ar
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],

@@ -95,39 +95,87 @@ if use_vllm_custom_allreduce:
         torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

 else:
+    if is_hip():
+
+        def init_custom_ar(
+            meta: torch.Tensor,
+            rank_data: torch.Tensor,
+            handles: List[str],
+            offsets: List[int],
+            rank: int,
+            full_nvlink: bool,
+        ) -> int:
+            return sgl_kernel.ops.init_custom_ar(
+                meta, rank_data, handles, offsets, rank, full_nvlink
+            )
+
+        def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+            sgl_kernel.ops.all_reduce_reg(fa, inp, out)
+
+        def all_reduce_unreg(
+            fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
+        ) -> None:
+            sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out)
+
+        def dispose(fa: int) -> None:
+            sgl_kernel.ops.dispose(fa)
+
+        def meta_size() -> int:
+            return sgl_kernel.ops.meta_size()
+
+        def register_buffer(
+            fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
+        ) -> None:
+            return sgl_kernel.ops.register_buffer(fa, t, handles, offsets)
+
+        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
+            return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+
+        def register_graph_buffers(
+            fa: int, handles: List[str], offsets: List[List[int]]
+        ) -> None:
+            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+
+        def allocate_meta_buffer(size: int) -> torch.Tensor:
+            return sgl_kernel.ops.allocate_meta_buffer(size)
+
+        def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
+            return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
+
+    else:
+        # custom ar
+        def init_custom_ar(
+            rank_id: int,
+            world_size: int,
+            rank_data_base: torch.Tensor,
+            buffers: List[int],
+            tmp_result_buffers: List[int],
+            barrier_in: List[int],
+            barrier_out: List[int],
+        ) -> int:
+            return sgl_kernel.ops.init_custom_reduce(
+                rank_id,
+                world_size,
+                rank_data_base,
+                buffers,
+                tmp_result_buffers,
+                barrier_in,
+                barrier_out,
+            )
+
+        def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+            sgl_kernel.ops.custom_reduce(fa, inp, out)
+
+        def dispose(fa: int) -> None:
+            sgl_kernel.ops.custom_dispose(fa)
+
+        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+            return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+
+        def register_graph_buffers(
+            fa: int, handles: List[List[int]], offsets: List[List[int]]
+        ) -> None:
+            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)

 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
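For reviewers: the selection logic that the shims above implement can be condensed into a small helper. The helper below is illustrative only and is not part of the commit; it simply mirrors the conditions in _custom_ops.py, where vLLM's _C_custom_ar bindings are used off-ROCm when USE_VLLM_CUSTOM_ALLREDUCE is truthy, the new sgl-kernel HIP bindings are used on ROCm, and the sgl-kernel trt_reduce bindings are used otherwise.

# Hypothetical helper summarizing the backend selection in _custom_ops.py.
# It is not part of the commit; the boolean arguments stand in for
# sglang.srt.utils.is_hip()/is_hpu().
import os


def custom_allreduce_backend(is_hip: bool, is_hpu: bool) -> str:
    """Return which custom all-reduce binding the shims would use."""
    if is_hpu:
        return "none"  # the shims are not set up on HPU
    # Mirrors the original: the raw env value is used for its truthiness.
    use_vllm = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
    if use_vllm and not is_hip:
        return "vllm._C custom_ar"          # torch.ops._C_custom_ar.*
    if is_hip:
        return "sgl_kernel HIP custom AR"   # sgl_kernel.ops.all_reduce_reg/unreg
    return "sgl_kernel trt_reduce"          # sgl_kernel.ops.custom_reduce


if __name__ == "__main__":
    print(custom_allreduce_backend(is_hip=True, is_hpu=False))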
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py  (+222 -80)

@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     gpu_p2p_access_check,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
-from sglang.srt.utils import cuda_device_count_stateless, is_cuda
+from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip

 logger = logging.getLogger(__name__)

@@ -28,14 +28,27 @@ if is_cuda():
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)

+if is_hip():
+    try:
+        from amdsmi import (
+            AmdSmiException,
+            amdsmi_get_gpu_board_info,
+            amdsmi_get_processor_handles,
+            amdsmi_init,
+            amdsmi_shut_down,
+            amdsmi_topo_get_link_type,
+        )
+    except ImportError as e:
+        logger.warning("Failed to import amdsmi with %r", e)
+
 try:
-    if ops.use_vllm_custom_allreduce:
+    if ops.use_vllm_custom_allreduce and not is_hip():
         ops.meta_size()
     else:
         import sgl_kernel

     custom_ar = True
 except Exception:
-    # For AMD GPUs and CPUs
+    # For CPUs
     custom_ar = False

 logger = logging.getLogger(__name__)

@@ -47,37 +60,62 @@ _R = TypeVar("_R")

 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        pynvml.nvmlInit()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            pynvml.nvmlShutdown()
+        if torch.version.hip:
+            try:
+                amdsmi_init()
+                return fn(*args, **kwargs)
+            finally:
+                amdsmi_shut_down()
+        else:
+            pynvml.nvmlInit()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                pynvml.nvmlShutdown()

     return wrapper


 @with_nvml_context
-def is_full_nvlink(physical_device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
-                    )
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError:
-                    logger.exception(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped."
-                    )
-                    return False
-    return True
+def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
+    if is_hip():
+        """
+        query if the set of gpus are fully connected by xgmi (1 hop)
+        """
+        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
+                        # type is 2 for XGMI
+                        if link_type["hops"] != 1 or link_type["type"] != 2:
+                            return False
+                    except AmdSmiException as error:
+                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
+                        return False
+        return True
+    else:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
+                        )
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError:
+                        logger.exception(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped."
+                        )
+                        return False
+        return True


 def _can_p2p(rank: int, world_size: int) -> bool:

@@ -102,15 +140,18 @@ def is_weak_contiguous(inp: torch.Tensor):

 class CustomAllreduce:

     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+    _MAX_CAR_SIZE = 8192 * 1024
+    if is_hip():
+        # crossover is at 16MB buffer size for ROCm
+        _MAX_CAR_SIZE = 2 * 8192 * 1024

     # max_size: max supported allreduce size
     def __init__(
         self,
         group: ProcessGroup,
         device: Union[int, str, torch.device],
-        max_size=8192 * 1024,
+        max_size=_MAX_CAR_SIZE,
     ) -> None:
         """
         Args:

@@ -185,12 +226,9 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda():
-            assert is_cuda()
-            full_nvlink = is_full_nvlink(physical_device_ids)
-        else:
-            full_nvlink = False
+        if is_cuda() or is_hip():
+            full_nvlink = is_full_nvlink(physical_device_ids, world_size)
+
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"

@@ -201,7 +239,8 @@ class CustomAllreduce:
         # test P2P capability, this checks software/cudaruntime support
         # this is expensive to compute at the first time
         # then we cache the result
-        if not _can_p2p(rank, world_size):
+        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
+        if not is_hip() and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "

@@ -214,7 +253,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink

-        if ops.use_vllm_custom_allreduce:
+        if ops.use_vllm_custom_allreduce and not is_hip():
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.

@@ -237,35 +276,56 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
+            if is_hip():
+                # meta data buffers need to be "uncached" for signal on MI200
+                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
+                self.buffer = torch.empty(
+                    max_size, dtype=torch.uint8, device=self.device
+                )
+                handle = ops.get_meta_buffer_ipc_handle(self.meta)
+                shard_data = (
+                    bytes(handle),  # ipc handle to base ptr
+                    0,  # offset of base ptr
+                )
+                handles, offsets = self._gather_ipc_meta(shard_data)
+                self.rank_data = torch.empty(
+                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+                )
+                self._ptr = ops.init_custom_ar(
+                    self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
+                )
+                self.register_buffer(self.buffer)
+                self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
+            else:
+                # From TensorRT-LLM getMaxRequiredWorkspaceSize
+                self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+                # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+                self.barrier_max_size = 8 * (36 + 2) * 8
+
+                self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+                self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+                    max_size, group=group
+                )
+                self.rank_data_base = torch.empty(
+                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+                )
+                self.barrier_in_ptrs = self.create_shared_buffer(
+                    self.barrier_max_size, group=group
+                )
+                self.barrier_out_ptrs = self.create_shared_buffer(
+                    self.barrier_max_size, group=group
+                )
+
+                self._ptr = ops.init_custom_ar(
+                    rank,
+                    world_size,
+                    self.rank_data_base,
+                    self.buffer_ptrs,
+                    self.tmp_result_buffer_ptrs,
+                    self.barrier_in_ptrs,
+                    self.barrier_out_ptrs,
+                )
         self.disabled = False

@@ -316,23 +376,69 @@ class CustomAllreduce:
         if not self.disabled:
             self.register_graph_buffers()

+    def _get_ipc_meta(self, inp: torch.Tensor):
+        # _share_cuda_() doesn't accept meta buffer not allocated from
+        # PyTorch cache allocator, use direct HIP call to get IPC handle
+        handle = ops.get_meta_buffer_ipc_handle(inp)
+        shard_data = (
+            bytes(handle),  # ipc handle to base ptr
+            0,  # offset of base ptr
+        )
+        return self._gather_ipc_meta(shard_data)
+
+    def _gather_ipc_meta(self, shard_data):
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it will create a list of the same reference
+        all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
+        all_data[self.rank][0] = shard_data
+
+        ranks = dist.get_process_group_ranks(group=self.group)
+        ranks.sort()
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(
+                all_data[i], src=rank, group=self.group, device="cpu"
+            )
+
+        # we cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        handles = []
+        offsets = []
+        for i in range(len(all_data)):
+            handles.append(all_data[i][0][0])  # type: ignore
+            offsets.append(all_data[i][0][1])  # type: ignore
+        return handles, offsets
+
+    def register_buffer(self, inp: torch.Tensor):
+        handles, offsets = self._get_ipc_meta(inp)
+        ops.register_buffer(self._ptr, inp, handles, offsets)
+
     def register_graph_buffers(self):
-        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-        logger.info("Registering %d cuda graph addresses", len(offset))
-        # We cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(
-                all_data[i], src=rank, group=self.group, device="cpu"
-            )
-        # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
-        ops.register_graph_buffers(self._ptr, handles, offsets)
+        if is_hip():
+            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
+            logger.info("Registering %d cuda graph addresses", len(offset))
+            ops.register_graph_buffers(self._ptr, handles, offsets)
+        else:
+            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+            logger.info("Registering %d cuda graph addresses", len(offset))
+            # We cannot directly use `dist.all_gather_object` here
+            # because it is incompatible with `gloo` backend under inference mode.
+            # see https://github.com/pytorch/pytorch/issues/126032 for details.
+            all_data = [
+                [None, None] for _ in range(dist.get_world_size(group=self.group))
+            ]
+            all_data[self.rank] = [handle, offset]
+            ranks = sorted(dist.get_process_group_ranks(group=self.group))
+            for i, rank in enumerate(ranks):
+                dist.broadcast_object_list(
+                    all_data[i], src=rank, group=self.group, device="cpu"
+                )
+            # Unpack list of tuples to tuple of lists.
+            handles = [d[0] for d in all_data]  # type: ignore
+            offsets = [d[1] for d in all_data]  # type: ignore
+            ops.register_graph_buffers(self._ptr, handles, offsets)

     def should_custom_ar(self, inp: torch.Tensor):
         if self.disabled:

@@ -345,11 +451,22 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce:
+        if ops.use_vllm_custom_allreduce and not is_hip():
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False

+        if is_hip():
+            if self.full_nvlink:
+                if self.world_size == 8:
+                    if self.MSCCL:
+                        return False
+                    else:
+                        return inp_size < self.max_size
+                else:
+                    return inp_size < self.max_size
+            return False
+
         if self.world_size == 2:
             return (
                 inp_size < self.max_size

@@ -364,6 +481,21 @@ class CustomAllreduce:
         return False

+    # all reduce, assuming inp tensor is IPC registered with register_buffer,
+    # or, in the context of cuda graphs, register_graph_buffers
+    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.all_reduce_reg(self._ptr, inp, out)
+        return out
+
+    # all reduce, assuming inp tensor is NOT IPC registered
+    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
+        return out
+
     def all_reduce(
         self,
         inp: torch.Tensor,

@@ -397,13 +529,23 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                if is_hip():
+                    return self.all_reduce_reg(input)
+                else:
+                    return self.all_reduce(input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            return self.all_reduce(input, registered=False)
+            if is_hip():
+                # note: outside of cuda graph context,
+                # custom allreduce incurs a cost of cudaMemcpy, which should
+                # be small(<=1% of overall latency) compared to the performance
+                # gains of using custom kernels
+                return self.all_reduce_unreg(input)
+            else:
+                return self.all_reduce(input, registered=False)

     def close(self):
         if not self.disabled and self._ptr:

@@ -411,7 +553,7 @@ class CustomAllreduce:
             if ops.use_vllm_custom_allreduce:
                 self.free_shared_buffer(self.meta_ptrs)
                 self.free_shared_buffer(self.buffer_ptrs)
-            else:
+            elif is_cuda():
                 self.free_shared_buffer(self.buffer_ptrs)
                 self.free_shared_buffer(self.tmp_result_buffer_ptrs)
                 self.free_shared_buffer(self.barrier_in_ptrs)
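The net effect of the new ROCm gating in should_custom_ar (full XGMI connectivity required, MSCCL preferred at world size 8, 16 MB size cap) can be sketched as a standalone function. This is a hedged illustration rather than code from the patch; rocm_should_custom_ar and its argument names are hypothetical.

# Illustrative sketch (not part of the commit) of the ROCm branch added to
# CustomAllreduce.should_custom_ar. `full_xgmi` corresponds to the value
# returned by is_full_nvlink() on HIP, and `msccl_enabled` to self.MSCCL.
ROCM_MAX_CAR_SIZE = 2 * 8192 * 1024  # 16 MiB crossover on ROCm


def rocm_should_custom_ar(
    inp_size: int, world_size: int, full_xgmi: bool, msccl_enabled: bool
) -> bool:
    if not full_xgmi:
        return False
    if world_size == 8 and msccl_enabled:
        # At world size 8 with MSCCL enabled, RCCL is preferred.
        return False
    return inp_size < ROCM_MAX_CAR_SIZE


assert rocm_should_custom_ar(1 << 20, 8, True, True) is False
assert rocm_should_custom_ar(1 << 20, 8, True, False) is True
assert rocm_should_custom_ar(32 << 20, 4, True, False) is False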
sgl-kernel/setup_rocm.py  (+1 -0)

@@ -44,6 +44,7 @@ include_dirs = [
 sources = [
     "src/sgl-kernel/torch_extension_rocm.cc",
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
+    "src/sgl-kernel/csrc/custom_all_reduce.hip",
 ]

 cxx_flags = ["-O3"]
sgl-kernel/src/sgl-kernel/__init__.py  (+127 -63)

The module now branches on torch.version.hip; the previous unconditional import block and __all__ list are carried over under the else branch. The file now reads:

import ctypes
import os

import torch

if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
    ctypes.CDLL(
        "/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
        mode=ctypes.RTLD_GLOBAL,
    )

from .version import __version__

if torch.version.hip is not None:
    from sgl_kernel.ops import (
        all_reduce_reg,
        all_reduce_unreg,
        allocate_meta_buffer,
        apply_rope_with_cos_sin_cache_inplace,
        bmm_fp8,
        dispose,
        fp8_scaled_mm,
        fused_add_rmsnorm,
        gelu_and_mul,
        gelu_tanh_and_mul,
        gemma_fused_add_rmsnorm,
        gemma_rmsnorm,
        get_graph_buffer_ipc_meta,
        get_meta_buffer_ipc_handle,
        init_custom_ar,
        int8_scaled_mm,
        lightning_attention_decode,
        meta_size,
        min_p_sampling_from_probs,
        moe_align_block_size,
        register_buffer,
        register_graph_buffers,
        rmsnorm,
        sampling_scaling_penalties,
        silu_and_mul,
        top_k_renorm_prob,
        top_k_top_p_sampling_from_probs,
        top_p_renorm_prob,
    )

    __all__ = [
        "all_reduce_reg",
        "all_reduce_unreg",
        "allocate_meta_buffer",
        "apply_rope_with_cos_sin_cache_inplace",
        "bmm_fp8",
        "dispose",
        "fp8_scaled_mm",
        "fused_add_rmsnorm",
        "gelu_and_mul",
        "gelu_tanh_and_mul",
        "gemma_fused_add_rmsnorm",
        "gemma_rmsnorm",
        "get_graph_buffer_ipc_meta",
        "get_meta_buffer_ipc_handle",
        "init_custom_ar",
        "int8_scaled_mm",
        "lightning_attention_decode",
        "meta_size",
        "min_p_sampling_from_probs",
        "moe_align_block_size",
        "register_buffer",
        "register_graph_buffers",
        "rmsnorm",
        "sampling_scaling_penalties",
        "silu_and_mul",
        "top_k_renorm_prob",
        "top_k_top_p_sampling_from_probs",
        "top_p_renorm_prob",
    ]
else:
    from sgl_kernel.ops import (
        apply_rope_with_cos_sin_cache_inplace,
        bmm_fp8,
        build_tree_kernel,
        build_tree_kernel_efficient,
        cublas_grouped_gemm,
        custom_dispose,
        custom_reduce,
        fp8_blockwise_scaled_mm,
        fp8_scaled_mm,
        fused_add_rmsnorm,
        gelu_and_mul,
        gelu_tanh_and_mul,
        gemma_fused_add_rmsnorm,
        gemma_rmsnorm,
        get_graph_buffer_ipc_meta,
        init_custom_reduce,
        int8_scaled_mm,
        lightning_attention_decode,
        min_p_sampling_from_probs,
        moe_align_block_size,
        register_graph_buffers,
        rmsnorm,
        sampling_scaling_penalties,
        sgl_per_token_group_quant_fp8,
        silu_and_mul,
        top_k_renorm_prob,
        top_k_top_p_sampling_from_probs,
        top_p_renorm_prob,
        tree_speculative_sampling_target_only,
    )

    __all__ = [
        "apply_rope_with_cos_sin_cache_inplace",
        "bmm_fp8",
        "cublas_grouped_gemm",
        "custom_dispose",
        "custom_reduce",
        "fp8_blockwise_scaled_mm",
        "fp8_scaled_mm",
        "fused_add_rmsnorm",
        "gelu_and_mul",
        "gelu_tanh_and_mul",
        "gemma_fused_add_rmsnorm",
        "gemma_rmsnorm",
        "get_graph_buffer_ipc_meta",
        "init_custom_reduce",
        "int8_scaled_mm",
        "lightning_attention_decode",
        "min_p_sampling_from_probs",
        "moe_align_block_size",
        "register_graph_buffers",
        "rmsnorm",
        "sampling_scaling_penalties",
        "silu_and_mul",
        "top_k_renorm_prob",
        "top_k_top_p_sampling_from_probs",
        "top_p_renorm_prob",
        "tree_speculative_sampling_target_only",
        "build_tree_kernel_efficient",
        "build_tree_kernel",
        "sgl_per_token_group_quant_fp8",
    ]
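Because the exported names now differ per platform, downstream callers have to branch the same way. A minimal, hedged usage sketch follows (it assumes an installed sgl-kernel build; the print statements are purely illustrative):

# Hedged sketch: pick the right sgl_kernel all-reduce entry points at import
# time, mirroring the torch.version.hip branch in sgl-kernel/__init__.py.
import torch

try:
    if torch.version.hip is not None:
        # ROCm build: vLLM-style custom AR maintained in sgl-kernel.
        from sgl_kernel import all_reduce_reg, all_reduce_unreg, init_custom_ar, meta_size

        print("signal buffer bytes required:", meta_size())
    else:
        # CUDA build: TRT-style reduce entry points.
        from sgl_kernel import custom_reduce, init_custom_reduce

        print("using sgl_kernel trt_reduce bindings")
except ImportError as err:  # sgl-kernel not installed or built for another platform
    print(f"sgl_kernel custom all-reduce unavailable: {err!r}")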
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce.hip  (new file, 0 → 100644, +180)
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/all.h>
#include "custom_all_reduce_hip.cuh"
// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (world_size != handles.size())
throw std::invalid_argument(
"handles length should equal to offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
hipIpcMemHandle_t ipc_handles[8];
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows transpose of contiguous slice (i.e. slicing the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
}
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
hipStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
_all_reduce(_fa, inp, out, stream);
}
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto input_size = inp.numel() * inp.element_size();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
"registered buffer is too small to contain the input");
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
input_size, hipMemcpyDeviceToDevice, stream));
_all_reduce(_fa, reg_buffer, out, stream);
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa;
}
int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
}
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
inp.data_ptr()));
return data_handle;
}
torch::Tensor allocate_meta_buffer(int64_t size) {
auto device_index = c10::hip::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(hipStreamSynchronize(stream));
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
auto options = torch::TensorOptions()
.dtype(torch::kI8)
.device(torch::kCUDA, device_index);
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}
std::vector<uint8_t> get_device_bdf(int dev) {
char busIdStr[] = "0000:00:00.0";
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
bdf.resize(bdf.size() - 1); // remove trailing NULL
return bdf;
}
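The _is_weak_contiguous check above is mirrored on the Python side by is_weak_contiguous in custom_all_reduce.py. The following is a stand-alone re-implementation for illustration only (not the project's function), exercising the examples from the comment:

# Illustrative only: a Python mirror of the _is_weak_contiguous rule used by
# the HIP kernels (the tensor's data must occupy one dense span of storage).
import torch


def is_weak_contiguous(t: torch.Tensor) -> bool:
    spanned = t.untyped_storage().nbytes() - t.storage_offset() * t.element_size()
    return t.is_contiguous() or spanned == t.numel() * t.element_size()


A = torch.zeros(3, 3, 3)
assert is_weak_contiguous(A)                                # 1. A: OK
assert is_weak_contiguous(A[1:])                            # 2. A[1:]: OK
assert is_weak_contiguous(A.permute(2, 0, 1))               # 3. A.permute(2, 0, 1): OK
assert is_weak_contiguous(A[1:].permute(2, 0, 1))           # 4. A[1:].permute(2, 0, 1): OK
assert not is_weak_contiguous(A[None].expand(2, -1, -1, -1))  # 5. not OK
assert not is_weak_contiguous(A[:, 1:, 1:])                 # 6. A[:, 1:, 1:]: not OK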
sgl-kernel/src/sgl-kernel/csrc/custom_all_reduce_hip.cuh  (new file, 0 → 100644, +554; diff collapsed in this view)
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h  (+17 -1)

@@ -34,8 +34,23 @@ limitations under the License.
     return PyModule_Create(&module); \
   }

-// trt_reduce
 using fptr_t = int64_t;
+#ifdef USE_ROCM
+fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
+                      const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
+void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
+void dispose(fptr_t _fa);
+int64_t meta_size();
+void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles,
+                     const std::vector<int64_t>& offsets);
+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets);
+torch::Tensor allocate_meta_buffer(int64_t size);
+torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
+#else
+// trt_reduce
 fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
                       const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
                       const std::vector<fptr_t>& barrier_out);

@@ -44,6 +59,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
+#endif

 // moe_align_block_size
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
sgl-kernel/src/sgl-kernel/ops/__init__.py  (+65 -14)

@@ -64,28 +64,79 @@ def apply_rope_with_cos_sin_cache_inplace(
 )

+if torch.version.hip is not None:
+
+    def init_custom_ar(
+        meta: torch.Tensor,
+        rank_data: torch.Tensor,
+        handles: List[str],
+        offsets: List[int],
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return torch.ops.sgl_kernels.init_custom_ar(
+            meta, rank_data, handles, offsets, rank, full_nvlink
+        )
+
+    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
+
+    def all_reduce_unreg(
+        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
+
+    def dispose(fa: int) -> None:
+        torch.ops.sgl_kernels.dispose(fa)
+
+    def meta_size() -> int:
+        return torch.ops.sgl_kernels.meta_size()
+
+    def register_buffer(
+        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
+    ) -> None:
+        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
+        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[str], offsets: List[List[int]]
+    ) -> None:
+        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
+
+    def allocate_meta_buffer(size: int) -> torch.Tensor:
+        return torch.ops.sgl_kernels.allocate_meta_buffer(size)
+
+    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
+        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
+
+else:
+    # trt_reduce
+    def init_custom_reduce(
+        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
+    ):
+        return torch.ops.sgl_kernels.init_custom_ar(
+            rank_id,
+            num_devices,
+            rank_data,
+            buffers,
+            tmp_buffers,
+            barrier_in,
+            barrier_out,
+        )
+
+    def custom_dispose(fa):
+        torch.ops.sgl_kernels.dispose(fa)
+
+    def custom_reduce(fa, inp, out):
+        torch.ops.sgl_kernels.all_reduce(fa, inp, out)
+
+    def get_graph_buffer_ipc_meta(fa):
+        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(fa, handles, offsets):
+        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)


 def moe_align_block_size(
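Taken together, the ROCm wrappers are meant to be driven the way CustomAllreduce.__init__ drives them: allocate an uncached meta buffer, exchange IPC handles across ranks, then create the kernel handle. The sketch below is schematic and hedged — it assumes an already-initialized torch.distributed default group with a CPU-capable backend (e.g. gloo), one HIP device per rank, and a ROCm build of sgl-kernel; build_custom_ar is a hypothetical helper, not part of the commit.

# Schematic sketch of the ROCm custom all-reduce setup exposed by sgl_kernel.ops.
# Assumes torch.distributed is already initialized on the default group and
# buffer sizes follow the defaults used in custom_all_reduce.py.
import torch
import torch.distributed as dist
import sgl_kernel.ops as ops

MAX_SIZE = 2 * 8192 * 1024  # 16 MiB, the ROCm _MAX_CAR_SIZE


def build_custom_ar(rank: int, world_size: int, full_xgmi: bool) -> int:
    device = torch.device("cuda", rank)
    # Uncached signal/metadata buffer plus scratch space, shared over IPC.
    meta = ops.allocate_meta_buffer(ops.meta_size() + MAX_SIZE)
    handle = bytes(ops.get_meta_buffer_ipc_handle(meta))

    # Gather (ipc_handle, offset) pairs from every rank; broadcast_object_list
    # is used instead of all_gather_object for gloo/inference-mode reasons.
    all_data = [[None] for _ in range(world_size)]
    all_data[rank][0] = (handle, 0)
    for src in range(world_size):
        dist.broadcast_object_list(all_data[src], src=src, device="cpu")
    handles = [d[0][0] for d in all_data]
    offsets = [d[0][1] for d in all_data]

    rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
    return ops.init_custom_ar(meta, rank_data, handles, offsets, rank, full_xgmi)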
sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc  (+31 -0)

@@ -19,6 +19,37 @@ limitations under the License.
 #include "sgl_kernels_ops.h"

 TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
+  // Custom all-reduce kernels
+  m.def(
+      "init_custom_ar(Tensor meta, Tensor rank_data, "
+      "str[] handles, int[] offsets, int rank, "
+      "bool full_nvlink) -> int");
+  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
+
+  m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
+  m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
+
+  m.def(
+      "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
+      "()");
+  m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
+
+  m.def("dispose", &dispose);
+  m.def("meta_size", &meta_size);
+
+  m.def(
+      "register_buffer(int fa, Tensor t, str[] handles, "
+      "int[] offsets) -> ()");
+  m.impl("register_buffer", torch::kCUDA, &register_buffer);
+
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", &register_graph_buffers);
+
+  m.def("allocate_meta_buffer", &allocate_meta_buffer);
+  m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
+  m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
+  m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
+
   // moe_align_block_size
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
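Once a ROCm build of sgl-kernel loads this extension, the schemas above are reachable under torch.ops.sgl_kernels. A hedged smoke test (the names follow the registrations above; the guard is illustrative):

# Hedged smoke test for the newly registered ROCm ops; requires a ROCm build
# of sgl-kernel so that torch_extension_rocm.cc has been loaded.
import torch

try:
    import sgl_kernel  # noqa: F401  (loads the torch extension)

    print("meta_size:", torch.ops.sgl_kernels.meta_size())
    buf = torch.ops.sgl_kernels.allocate_meta_buffer(4096)
    handle = torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(buf)
    print("ipc handle bytes:", handle.numel())
except (ImportError, AttributeError, RuntimeError) as err:
    print(f"sgl_kernels ROCm ops unavailable: {err!r}")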