sglang · Commit 6cb3974e

optimize custom allreduce kernel (#2904)

Authored by yizhang2077 on Jan 16, 2025; committed via GitHub on Jan 16, 2025 (unverified commit).
Parent: f65c13b5

Showing 9 changed files with 244 additions and 80 deletions (+244 / -80).
Changed files:

  sgl-kernel/pyproject.toml                                 +1   -1
  sgl-kernel/setup.py                                       +1   -1
  sgl-kernel/src/sgl-kernel/__init__.py                     +4   -0
  sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu          +8   -2
  sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu    +83  -56
  sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh    +7   -1
  sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu      +109   -8
  sgl-kernel/src/sgl-kernel/ops/__init__.py                +18   -2
  sgl-kernel/tests/test_trt_reduce.py                      +13   -9
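The shape of the change is easiest to see from the Python side before reading the kernels: init_custom_reduce grows from five to seven arguments (a per-rank rank_data scratch tensor and a set of shared tmp result buffers), and two new entry points, get_graph_buffer_ipc_meta and register_graph_buffers, are exported for CUDA-graph support. A minimal sketch of the new call shape, mirroring the updated test; create_shared_buffer stands in for the IPC helper defined in test_trt_reduce.py and is assumed here:

import torch
import sgl_kernel


def init_custom_allreduce_sketch(rank: int, world_size: int, create_shared_buffer):
    # Sizes follow sgl-kernel/tests/test_trt_reduce.py in this commit.
    buffer_max_size = 8 * 1024 * 1024    # shared communication buffers
    barrier_max_size = 8 * (24 + 2) * 8  # barrier flag buffers

    buffer_ptrs = create_shared_buffer(buffer_max_size)
    tmp_result_buffer_ptrs = create_shared_buffer(buffer_max_size)  # new in this commit
    barrier_in_ptrs = create_shared_buffer(barrier_max_size)
    barrier_out_ptrs = create_shared_buffer(barrier_max_size)

    # New: scratch tensor that backs the RankData records used when CUDA-graph
    # input buffers are registered later.
    rank_data = torch.empty(
        8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
    )

    # New seven-argument signature (previously five arguments).
    return sgl_kernel.ops.init_custom_reduce(
        rank,
        world_size,
        rank_data,
        buffer_ptrs,
        tmp_result_buffer_ptrs,
        barrier_in_ptrs,
        barrier_out_ptrs,
    )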
sgl-kernel/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post12"
+version = "0.0.2.post13"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
sgl-kernel/setup.py

@@ -40,7 +40,7 @@ nvcc_flags = [
     "-U__CUDA_NO_HALF2_OPERATORS__",
 ]
 cxx_flags = ["-O3"]
-libraries = ["c10", "torch", "torch_python"]
+libraries = ["c10", "torch", "torch_python", "cuda"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
 ext_modules = [
     CUDAExtension(
sgl-kernel/src/sgl-kernel/__init__.py
View file @
6cb3974e
from
sgl_kernel.ops
import
(
from
sgl_kernel.ops
import
(
custom_dispose
,
custom_dispose
,
custom_reduce
,
custom_reduce
,
get_graph_buffer_ipc_meta
,
init_custom_reduce
,
init_custom_reduce
,
int8_scaled_mm
,
int8_scaled_mm
,
moe_align_block_size
,
moe_align_block_size
,
register_graph_buffers
,
sampling_scaling_penalties
,
sampling_scaling_penalties
,
)
)
...
@@ -14,4 +16,6 @@ __all__ = [
...
@@ -14,4 +16,6 @@ __all__ = [
"custom_reduce"
,
"custom_reduce"
,
"int8_scaled_mm"
,
"int8_scaled_mm"
,
"sampling_scaling_penalties"
,
"sampling_scaling_penalties"
,
"get_graph_buffer_ipc_meta"
,
"register_graph_buffers"
,
]
]
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu

@@ -2,10 +2,14 @@
 // trt_reduce
 using fptr_t = int64_t;
-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out);
 void dispose(fptr_t _fa);
 void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets);

 // moe_align_block_size
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,

@@ -25,6 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
   m.def("dispose", &dispose, "dispose custom allreduce meta");
   m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta");
+  m.def("register_graph_buffers", &register_graph_buffers, "custom all reduce register graph buffers");
   // moe_align_block_size
   m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
   // sampling_scaling_penalties
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu

@@ -126,10 +126,10 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
   __syncthreads();
 }

+template <bool start, bool need_fence = false>
 __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
-                                         size_t const world_size, int const tidx, int const bidx, int const grid_size,
-                                         bool start = true, bool need_fence = false) {
-  if (!start) {
+                                         size_t const world_size, int const tidx, int const bidx, int const grid_size) {
+  if constexpr (!start) {
     __syncthreads();
   }
   // After this function, the block of id == bidx of each GPU has reached the barrier

@@ -141,22 +141,16 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
   // Block broadcast its flag (local_rank on emitting dimension) to all receivers
   uint32_t flag_block_offset = world_size + bidx * world_size;

-  if (flag % 2 == 1) {
-    flag_block_offset += (grid_size + 1) * world_size;
-  }
+  flag_block_offset += (grid_size + 1) * world_size * (flag % 2);

-  if (need_fence) {
-    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
-  } else {
-    st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
-  }
-  // Blocks check that corresponding blocks on other GPUs have also set the flag
   uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
-
-  if (need_fence) {
+  // Blocks check that corresponding blocks on other GPUs have also set the flag
+  if constexpr (need_fence) {
+    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
     while (ld_flag_acquire(peer_barrier_d) != flag) {
     }
   } else {
+    st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
     while (ld_flag_volatile(peer_barrier_d) != flag) {
     }
   }
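The barrier rewrite makes start/need_fence compile-time template parameters, folds the flag-bank selection into a single multiply, and merges the store and wait paths under one if constexpr (need_fence). The offset arithmetic itself is unchanged: even and odd barrier flags still target two disjoint banks of flag slots, so consecutive barriers do not reuse slots that peers may still be polling. A small Python check of that equivalence (illustrative only; the constants are arbitrary):

def offset_branch(world_size, bidx, grid_size, flag):
    # Old form: pick the second flag bank with an explicit branch.
    off = world_size + bidx * world_size
    if flag % 2 == 1:
        off += (grid_size + 1) * world_size
    return off


def offset_branchless(world_size, bidx, grid_size, flag):
    # New form: the same bank selection folded into one multiply by (flag % 2).
    return world_size + bidx * world_size + (grid_size + 1) * world_size * (flag % 2)


# The two forms agree for every flag parity.
for flag in range(4):
    assert offset_branch(8, 3, 4, flag) == offset_branchless(8, 3, 4, flag)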
@@ -165,7 +159,7 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
   __syncthreads();
 }

-template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
 static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
   // The message is partitioned into chunks as detailed below:

@@ -193,6 +187,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   int const bidx = blockIdx.x;
   int const tidx = threadIdx.x;
+  int const grid_size = gridDim.x;

   // The number of elements packed into one for comms
   static constexpr int NUM_ELTS = 16 / sizeof(T);

@@ -201,18 +196,23 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   using PackedStruct = typename PackedOn16Bytes<T>::Type;

   // The source pointers. Distributed round-robin for the different warps.
-  T const* buffers[RANKS_PER_NODE];
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   // Start and end offsets of the thread
   size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
   size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
-#pragma unroll
-  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
-  }
-  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
+
+  if constexpr (COPY_INPUT) {
+    T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+    // Copy from local buffer to shareable buffer
+    for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
+          *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
+    }
+  }
+  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);

   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
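With COPY_INPUT = true the one-shot kernel now stages each block's chunk from the local input tensor into the IPC-shared buffer itself (replacing the host-side cudaMemcpyAsync removed later in this diff) and synchronizes with block_barrier<true> before reducing; with COPY_INPUT = false, used under CUDA graph capture, the staging loop is skipped and peers' registered input buffers are read in place. A sketch of the per-thread chunk arithmetic that the copy and reduce loops iterate over (fp16 element size assumed for the example):

def one_shot_offsets(elts_per_rank, elts_per_block, block_dim, bidx, tidx, elt_bytes=2):
    # NUM_ELTS: elements covered by one 16-byte (int4) vectorized access; 8 for fp16.
    num_elts = 16 // elt_bytes
    chunk_start = bidx * elts_per_block + tidx * num_elts
    chunk_end = min((bidx + 1) * elts_per_block, elts_per_rank)
    # Offsets this thread touches, stepping by the whole block's footprint.
    return list(range(chunk_start, chunk_end, block_dim * num_elts))


# fp16, 4096 elements per rank, 2 blocks of 128 threads: block 1, thread 2.
assert one_shot_offsets(4096, 2048, 128, bidx=1, tidx=2) == [2064, 3088]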
@@ -220,7 +220,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
     PackedStruct vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
     }

     // Sum the values from the different ranks.

@@ -229,8 +229,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }

     // Store to the destination buffer.

@@ -238,7 +237,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   }
 }

-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
 static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
   // The message is partitioned into chunks as detailed below:
@@ -286,20 +285,24 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
   static constexpr int PACKED_ELTS = 16 / sizeof(T);
   using PackedType = typename PackedOn16Bytes<T>::Type;

-  T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
+  T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);

   size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
   size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);

   T* buffers[RANKS_PER_NODE];
+  T* buffers_unorder[RANKS_PER_NODE];
   int ranks[RANKS_PER_NODE];
 #pragma unroll
   for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
     // A mapping of the ranks to scatter reads as much as possible
     int rank = (params.local_rank + ii) % RANKS_PER_NODE;
     ranks[ii] = rank;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+    buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
+    buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
   }

 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))

@@ -308,7 +311,21 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 #endif
 #endif

-  block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size);
+  if constexpr (COPY_INPUT) {
+    // Copy all blocks from local buffer to shareable buffer
+    for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+#pragma unroll
+      for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+        size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
+        if (offset_rank >= params.elts_total) {
+          continue;
+        }
+        *reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
+            *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
+      }
+    }
+  }
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);

   // Each block accumulates the values from the different GPUs on the same node.

@@ -319,7 +336,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
     PackedType vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
     }

     // Sum the values from the different ranks.

@@ -328,16 +345,19 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }

-    // Store to the local buffer.
-    *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    // Store to the local buffer or tmp buffer
+    if constexpr (COPY_INPUT) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    } else {
+      *reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
+    }
   }

-  block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size, false, true);
+  block_barrier<false, true>(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
+                             bidx, grid_size);

   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {

@@ -348,8 +368,13 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
       if (offset_rank >= params.elts_total) {
         continue;
       }
-
+      if constexpr (COPY_INPUT) {
         *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
             *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+      } else {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
+      }
     }
   }

 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
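The two-shot path keeps its reduce-scatter / all-gather structure, but when COPY_INPUT = false the partial sums land in params.tmp_result_buffers[local_rank] instead of being written back into the shared input buffer, and the gather phase reads each peer's tmp buffer. The gather indexing itself is unchanged: for a given local offset, peer rank r contributes the slice it reduced at r * elts_per_rank + local_offset. A small illustration of that mapping:

def two_shot_gather_offsets(world_size, elts_per_rank, local_offset):
    # In the gather phase, for each peer rank r, a block reads the slice that
    # rank r reduced: offset = r * elts_per_rank + local_offset.
    return {r: r * elts_per_rank + local_offset for r in range(world_size)}


# Example: 4 ranks, 1024 elements per rank, this block currently at local offset 128.
assert two_shot_gather_offsets(4, 1024, 128) == {0: 128, 1: 1152, 2: 2176, 3: 3200}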
@@ -417,48 +442,50 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
 void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
                        cudaStream_t stream) {
   switch (algo) {
     case AllReduceStrategyType::ONESHOT: {
-      oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
     case AllReduceStrategyType::TWOSHOT: {
-      twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
   }
 }

-template <typename T>
-void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
-  void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
-  void* local_inp_buffer = param.local_input_buffer_ptr;
-  CHECK_CUDA_SUCCESS(
-      cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
-
-  CHECK_CUDA_SUCCESS(cudaGetLastError());
-
+template <typename T, bool COPY_INPUT>
+void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
   size_t elts_per_thread = 16 / sizeof(T);
   auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
   switch (param.ranks_per_node) {
     case 2:
-      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     case 4:
-      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     case 6:
-      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     case 8:
-      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     default:
      break;
   }
+}
+
+template <typename T>
+void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
+  if (param.is_capturing) {
+    dispatchARKernelsCopyInput<T, false>(strat, param, stream);
+  } else {
+    dispatchARKernelsCopyInput<T, true>(strat, param, stream);
+  }
   CHECK_CUDA_SUCCESS(cudaGetLastError());
 }
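On the host side, the old invokeOneOrTwoShotAllReduceKernel, which always staged the input with cudaMemcpyAsync, is split in two: dispatchARKernelsCopyInput<T, COPY_INPUT> resolves the launch configuration and the RANKS_PER_NODE specialization, and invokeOneOrTwoShotAllReduceKernel now only chooses COPY_INPUT from param.is_capturing. A compact Python mirror of that decision path, only to make the control flow explicit (the names echo the C++; this is not an exposed API):

def invoke_one_or_two_shot(param, strat, launch):
    # Mirrors invokeOneOrTwoShotAllReduceKernel: capture state picks COPY_INPUT.
    copy_input = not param["is_capturing"]
    ranks_per_node = param["ranks_per_node"]
    if ranks_per_node not in (2, 4, 6, 8):
        return  # same as the C++ default case: nothing is launched
    launch(strat, ranks_per_node, copy_input)


calls = []
invoke_one_or_two_shot(
    {"is_capturing": True, "ranks_per_node": 8},
    "TWOSHOT",
    lambda s, r, c: calls.append((s, r, c)),
)
assert calls == [("TWOSHOT", 8, False)]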
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh

@@ -36,6 +36,10 @@ enum class AllReduceStrategyType : int8_t {
   AUTO = 3,
 };

+struct RankData {
+  void* ptrs[MAX_RANKS_PER_NODE];
+};
+
 struct AllReduceParams {
   size_t elts_size;
   size_t elts_total;

@@ -46,9 +50,11 @@ struct AllReduceParams {
   uint32_t barrier_flag;
   uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
   uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
-  void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
+  uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
+  RankData* peer_comm_buffer_ptrs;
   void* local_input_buffer_ptr;
   void* local_output_buffer_ptr;
+  bool is_capturing;
 };

 inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
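This header change carries the CUDA-graph support: peer_comm_buffer_ptrs becomes a device pointer to a RankData record instead of an inline array, so a captured graph can bake in a stable pointer whose per-rank buffer addresses are filled in later, while tmp_result_buffers and is_capturing carry the capture-mode state into the kernels. A back-of-the-envelope check of how many RankData entries the 8 MiB rank_data tensor from the test can hold, assuming 64-bit device pointers and MAX_RANKS_PER_NODE equal to 8 (both are assumptions, not restated in this hunk):

POINTER_BYTES = 8            # assumption: 64-bit device pointers
MAX_RANKS_PER_NODE = 8       # assumption: matches the sgl-kernel constant

rank_data_bytes = 8 * 1024 * 1024                      # torch.empty(8 * 1024 * 1024, dtype=torch.uint8)
rank_data_entry = MAX_RANKS_PER_NODE * POINTER_BYTES   # sizeof(RankData) under these assumptions
max_registered_buffers = rank_data_bytes // rank_data_entry
assert max_registered_buffers == 131072                # one entry per captured allreduce input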
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu

@@ -12,25 +12,46 @@
 using namespace trt_llm;

 using fptr_t = int64_t;
+using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;

 class AllReduceMeta {
  public:
-  AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
+  AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
+                const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
+                const std::vector<fptr_t>& barrier_out) {
     this->rank_id = (int)rank_id;
     this->world_size = (int)world_size;
-    this->buffers = buffers;
     this->barrier_in = barrier_in;
     this->barrier_out = barrier_out;
+    this->tmp_result_buffers = tmp_result_buffers;
+
+    this->rank_data_base = reinterpret_cast<RankData*>(rank_data.data_ptr());
+    RankData data;
+    for (int i = 0; i < world_size; i++) {
+      data.ptrs[i] = (void*)buffers[i];
+    }
+    auto d_data = this->rank_data_base++;
+    CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
+    this->buffers = d_data;
+  }
+
+  ~AllReduceMeta() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
+    }
   }

  public:
   int world_size;
   int rank_id;
-  std::vector<fptr_t> buffers;
   std::vector<fptr_t> barrier_in;
   std::vector<fptr_t> barrier_out;
+  std::vector<fptr_t> tmp_result_buffers;
   int barrier_flag = 1;
+  RankData* buffers;
+  RankData* rank_data_base;
+  std::vector<void*> graph_unreg_buffers;
+  std::map<IPC_KEY, char*> ipc_handles_;
 };

 // Get the number of bits for a given data type.
@@ -52,9 +73,10 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
   return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
 }

-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
-  auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out);
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out) {
+  auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
   return (fptr_t)m;
 }
@@ -63,6 +85,75 @@ void dispose(fptr_t _fa) {
   delete fa;
 }

+std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
+  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
+  auto num_buffers = m->graph_unreg_buffers.size();
+  auto handle_sz = sizeof(cudaIpcMemHandle_t);
+  std::string handles(handle_sz * num_buffers, static_cast<char>(0));
+  std::vector<int64_t> offsets(num_buffers);
+  for (int i = 0; i < num_buffers; i++) {
+    auto ptr = m->graph_unreg_buffers[i];
+    void* base_ptr;
+    // note: must share the base address of each allocation, or we get wrong address
+    if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) {
+      assert(false && "failed to get pointer attr");
+    }
+    CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
+    offsets[i] = ((char*)ptr) - ((char*)base_ptr);
+  }
+  std::vector<int64_t> bytes(handles.begin(), handles.end());
+  return std::make_pair(bytes, offsets);
+}
+
+char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
+  auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
+  if (new_handle) {
+    char* ipc_ptr;
+    CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle),
+                                            cudaIpcMemLazyEnablePeerAccess));
+    it->second = ipc_ptr;
+  }
+  return it->second;
+}
+
+// Note: when registering graph buffers, we intentionally choose to not
+// deduplicate the addresses. That means if the allocator reuses some
+// addresses, they will be registered again. This is to account for the remote
+// possibility of different allocation patterns between ranks. For example,
+// rank 1 may get the same input address for the second allreduce, but rank 2
+// got a different address. IPC handles have internal reference counting
+// mechanism so overhead should be small.
+void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets) {
+  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
+  std::vector<std::string> handle_bytes;
+  handle_bytes.reserve(handles.size());
+  for (int i = 0; i < handles.size(); i++) {
+    handle_bytes.emplace_back(handles[i].begin(), handles[i].end());
+  }
+  auto num_buffers = m->graph_unreg_buffers.size();
+  std::vector<RankData> rank_data(num_buffers);
+  for (int i = 0; i < num_buffers; i++) {
+    auto self_ptr = m->graph_unreg_buffers[i];
+    auto& rd = rank_data[i];
+    for (int j = 0; j < m->world_size; j++) {
+      if (j != m->rank_id) {
+        char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]);
+        handle += offsets[j][i];
+        rd.ptrs[j] = handle;
+      } else {
+        rd.ptrs[j] = self_ptr;
+      }
+    }
+  }
+  CHECK_CUDA_SUCCESS(
+      cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
+  m->rank_data_base += num_buffers;
+  m->graph_unreg_buffers.clear();
+}
+
 void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
   AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
   auto stream = c10::cuda::getCurrentCUDAStream().stream();
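get_graph_buffer_ipc_meta serializes a cudaIpcMemHandle_t plus an allocation-base offset for every input pointer recorded in graph_unreg_buffers, and register_graph_buffers rebuilds one RankData entry per captured allreduce from every rank's handles via open_ipc_handle. Exchanging that metadata between ranks is left to the caller; a hedged sketch of one way to wire it up with torch.distributed (this glue is not part of the commit, and all_gather_object is only one possible transport):

import torch.distributed as dist
from sgl_kernel import ops as custom_ops


def register_captured_buffers(custom_ptr, world_size):
    # Collect this rank's IPC handles and offsets for buffers recorded during capture.
    handle_bytes, offsets = custom_ops.get_graph_buffer_ipc_meta(custom_ptr)

    # Share both lists with every other rank (illustrative; any exchange works).
    all_meta = [None] * world_size
    dist.all_gather_object(all_meta, (handle_bytes, offsets))

    handles = [meta[0] for meta in all_meta]
    all_offsets = [meta[1] for meta in all_meta]
    custom_ops.register_graph_buffers(custom_ptr, handles, all_offsets)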
@@ -87,8 +178,18 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
   params.elts_size = inp.element_size();
   params.barrier_flag = ++(m->barrier_flag);

+  cudaStreamCaptureStatus status;
+  CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
+  params.is_capturing = (status == cudaStreamCaptureStatusActive);
+  if (params.is_capturing) {
+    params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size();
+    m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr);
+  } else {
+    params.peer_comm_buffer_ptrs = m->buffers;
+  }
+
   for (int i = 0; i < world_size; ++i) {
-    params.peer_comm_buffer_ptrs[i] = reinterpret_cast<void*>(m->buffers[i]);
+    params.tmp_result_buffers[i] = reinterpret_cast<uint32_t*>(m->tmp_result_buffers[i]);
   }
   for (int i = 0; i < world_size; ++i) {
     params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
sgl-kernel/src/sgl-kernel/ops/__init__.py

 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
 from sgl_kernel.ops._kernels import dispose as _dispose
+from sgl_kernel.ops._kernels import (
+    get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
+)
 from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
 from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
+from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
 from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )


-def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
-    return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
+def init_custom_reduce(
+    rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
+):
+    return _init_custom_ar(
+        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
+    )


 def custom_dispose(fa):

@@ -20,6 +28,14 @@ def custom_reduce(fa, inp, out):
     _all_reduce(fa, inp, out)


+def get_graph_buffer_ipc_meta(fa):
+    return _get_graph_buffer_ipc_meta(fa)
+
+
+def register_graph_buffers(fa, handles, offsets):
+    _register_graph_buffers(fa, handles, offsets)
+
+
 def moe_align_block_size(
     topk_ids,
     num_experts,
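With these wrappers in place, the intended flow under CUDA graphs is: capture custom_reduce, then collect and register the recorded input buffers before replay. A hedged sketch of that flow (the capture driver is not part of this commit; warm-up, stream handling, and the metadata exchange from the earlier register_captured_buffers sketch are assumed):

import torch
from sgl_kernel import ops as custom_ops


def capture_allreduce(custom_ptr, inp, out):
    # Record a single custom_reduce call; while capturing, the host wrapper
    # stores inp's pointer in graph_unreg_buffers and the kernel skips the
    # staging copy (COPY_INPUT = false).
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        custom_ops.custom_reduce(custom_ptr, inp, out)
    return graph


# After capture: exchange IPC metadata and call register_graph_buffers on every
# rank, then graph.replay() reruns the allreduce on whatever data sits in inp.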
sgl-kernel/tests/test_trt_reduce.py

@@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
 import ray
 import torch
 import torch.distributed as dist
+from sgl_kernel import ops as custom_ops
 from torch.distributed import ProcessGroup
 from vllm import _custom_ops as vllm_ops

@@ -104,35 +105,38 @@ class TestCustomAllReduce(unittest.TestCase):
         multi_process_parallel(world_size, self, self.performance)

     def init_custom_allreduce(self, rank, world_size, group):
-        import sgl_kernel
-
         buffer_max_size = 8 * 1024 * 1024
         barrier_max_size = 8 * (24 + 2) * 8

         self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
+        self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+            buffer_max_size, group=group
+        )
         self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
         self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
+        self.rank_data = torch.empty(
+            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
+        )

-        self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
+        self.custom_ptr = custom_ops.init_custom_reduce(
             rank,
             world_size,
+            self.rank_data,
             self.buffer_ptrs,
+            self.tmp_result_buffer_ptrs,
             self.barrier_in_ptrs,
             self.barrier_out_ptrs,
         )

     def custom_allreduce(self, inp, out):
-        import sgl_kernel
-
-        sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)
+        custom_ops.custom_reduce(self.custom_ptr, inp, out)

     def free_custom_allreduce(self, group):
-        import sgl_kernel
-
         self.free_shared_buffer(self.buffer_ptrs, group)
+        self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
         self.free_shared_buffer(self.barrier_in_ptrs, group)
         self.free_shared_buffer(self.barrier_out_ptrs, group)
-        sgl_kernel.ops.custom_dispose(self.custom_ptr)
+        custom_ops.custom_dispose(self.custom_ptr)

     def init_vllm_allreduce(self, rank, group):
         self.vllm_rank = rank