sglang / Commits / bcbbf519
Unverified commit bcbbf519, authored Apr 06, 2025 by Yi Zhang, committed by GitHub on Apr 05, 2025

sgl-kernel transfer custom allreduce from trt kernel to vllm kernel (#5079)
parent 0d99adb7

Showing 10 changed files with 692 additions and 937 deletions (+692 −937)
sgl-kernel/CMakeLists.txt                          +1    −2
sgl-kernel/csrc/allreduce/custom_all_reduce.cu     +137  −0
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh    +489  −0
sgl-kernel/csrc/allreduce/trt_reduce_internal.cu   +0    −532
sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu     +0    −226
sgl-kernel/csrc/common_extension.cc                +7    −4
sgl-kernel/include/sgl_kernel_ops.h                +7    −10
sgl-kernel/include/trt_reduce_internal.cuh         +0    −109
sgl-kernel/python/sgl_kernel/allreduce.py          +26   −16
sgl-kernel/tests/test_custom_allreduce.py          +25   −38
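At the C++ boundary, the commit replaces the TensorRT-LLM-style setup (separate data, tmp-result and barrier pointer lists passed at construction) with the vLLM-style interface (one list of IPC signal pointers at init, plus an explicit buffer-registration step). The sketch below only contrasts the two host APIs as they appear in the sgl_kernel_ops.h diff further down this page; it is a reading aid, not code added by the commit.

// Before (TRT-based kernels): one call carries every shared buffer.
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data,
                      const std::vector<fptr_t>& buffers, const std::vector<fptr_t>& tmp_result_buffers,
                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);

// After (vLLM-based kernels): only the Signal buffers at init; data buffers are
// registered separately, and all_reduce takes an optional pre-registered staging buffer.
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data,
                      int64_t rank, bool full_nvlink);
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
int64_t meta_size();
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);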
sgl-kernel/CMakeLists.txt  (+1 −2)

@@ -157,8 +157,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

 set(SOURCES
-  "csrc/allreduce/trt_reduce_internal.cu"
-  "csrc/allreduce/trt_reduce_kernel.cu"
+  "csrc/allreduce/custom_all_reduce.cu"
   "csrc/attention/lightning_attention_decode_kernel.cu"
   "csrc/elementwise/activation.cu"
   "csrc/elementwise/fused_add_rms_norm_kernel.cu"
...
sgl-kernel/csrc/allreduce/custom_all_reduce.cu  (new file, mode 0 → 100644, +137)
// Adapted from: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include "custom_all_reduce.cuh"
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink) {
  int world_size = fake_ipc_ptrs.size();
  if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
  if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");

  vllm::Signal* ipc_ptrs[8];
  for (int i = 0; i < world_size; i++) {
    ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
  }
  return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows transpose of contiguous slice (i.e. slicing the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}
/**
* Performs an out-of-place allreduce and stores result in out.
*
* If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(
          stream, reinterpret_cast<float*>(reg_buffer), reinterpret_cast<float*>(out.data_ptr()), out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(
          stream, reinterpret_cast<half*>(reg_buffer), reinterpret_cast<half*>(out.data_ptr()), out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer), reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
  }
}
void dispose(fptr_t _fa) {
  delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
}

int64_t meta_size() {
  return sizeof(vllm::Signal);
}

void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
  void* ipc_ptrs[8];
  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
  }
  fa->register_buffer(ipc_ptrs);
}

// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
  std::vector<int64_t> bytes(handle.begin(), handle.end());
  return std::make_tuple(bytes, offsets);
}

// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (int i = 0; i < handles.size(); i++) {
    bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  bytes.reserve(handles.size());
  fa->register_graph_buffers(bytes, offsets);
}
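The entry points above are meant to be driven from the Python wrapper (sgl-kernel/python/sgl_kernel/allreduce.py, not shown on this page), which handles the CUDA IPC handle exchange. Purely as an illustration of the expected call order, here is a hypothetical host-side driver for one rank, assuming the IPC handles have already been exchanged and opened into per-rank pointers; the names and buffer layout here are assumptions, not part of the commit.

// Hypothetical driver for one rank; IPC exchange and error handling omitted.
void run_once(std::vector<fptr_t> signal_ptrs,  // opened IPC pointers to every rank's Signal buffer
              std::vector<fptr_t> data_ptrs,    // opened IPC pointers to every rank's staging buffer
              torch::Tensor rank_data,          // small device tensor backing the RankData slots
              int64_t rank, torch::Tensor inp, torch::Tensor out) {
  // Each signal buffer is assumed to be at least meta_size() plus a staging area, as the
  // Signal layout comment in custom_all_reduce.cuh describes.
  fptr_t fa = init_custom_ar(signal_ptrs, rank_data, rank, /*full_nvlink=*/true);
  register_buffer(fa, data_ptrs);  // make the shared staging buffers known to all ranks
  // data_ptrs[rank] is assumed to be this rank's registered staging buffer; inp is copied into it.
  all_reduce(fa, inp, out, /*_reg_buffer=*/data_ptrs[rank],
             /*reg_buffer_sz_bytes=*/inp.numel() * inp.element_size());
  dispose(fa);
}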
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh  (new file, mode 0 → 100644, +489)
// Adapted from https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cuh
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <array>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#include "utils.h"
namespace vllm {

constexpr int kMaxBlocks = 36;

// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;

struct Signal {
  alignas(128) FlagType self_counter[kMaxBlocks][8];
  // Two sets of peer counters are needed for two syncs. The reason is that
  // it's possible for peer GPU block to arrive at the second sync point while
  // the current GPU block haven't passed the first sync point. Thus, peer GPU
  // may write counter+1 while current GPU is busy waiting for counter. We use
  // alternating counter array to avoid this possibility.
  alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};

struct __align__(16) RankData {
  const void* __restrict__ ptrs[8];
};

struct __align__(16) RankSignals {
  Signal* signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) {
  return __half2float(val);
}

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float& assign_add(float& a, float b) {
  return a += b;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) {
  return __bfloat162float(val);
}
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}

static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(flag) : "l"(flag_addr));
#endif
  return flag;
}

static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, int rank) {
  if constexpr (!is_start) __syncthreads();
  static_assert(!(is_start && need_fence));  // Start barrier shouldn't need fence.
  if (threadIdx.x < ngpus) {
    // Increment the counter. Technically we only need one counter, but we use
    // multiple per block to eliminate the need to share the counter via smem.
    auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
    // Write the expected counter value to peer and wait for correct value from
    // peer.
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
    if constexpr (need_fence) {
      st_flag_release(peer_counter_ptr, val);
      while (ld_flag_acquire(self_counter_ptr) != val)
        ;
    } else {
      st_flag_volatile(peer_counter_ptr, val);
      while (ld_flag_volatile(self_counter_ptr) != val)
        ;
    }
  }
  if constexpr (is_start || need_fence) __syncthreads();
}

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}

template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  RankSignals sg_;
  // Stores an map from a pointer to its peer pointters from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each addresses used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;
/**
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
*
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
  CustomAllreduce(
      Signal** signals, void* rank_data, size_t rank_data_sz, int rank, int world_size, bool full_nvlink = true)
      : rank_(rank),
        world_size_(world_size),
        full_nvlink_(full_nvlink),
        self_sg_(signals[rank]),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
  }

  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
          (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }
  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }
/**
* Register already-shared IPC pointers.
*/
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CHECK_CUDA_SUCCESS(
        cudaMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }
/**
* Performs allreduce, assuming input has already been registered.
*
* Block and grid default configs are results after careful grid search. Using
* 36 blocks give the best or close to the best runtime on the devices I
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
* take a small amount of SMs. Not quite sure the underlying reason, but my
* guess is that too many SMs will cause contention on NVLink bus.
*/
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error(
          "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
    }
  }
};
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
*/
}  // namespace vllm
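For readers skimming the header, the dispatch policy buried in the REDUCE_CASE macro above can be read as a small standalone decision: one-stage for 2 GPUs or for small messages over full NVLink, two-stage for larger messages, and nothing launched without full NVLink beyond 2 GPUs. The paraphrase below only spells out those thresholds for illustration; it is not code added by the commit.

// Mirrors the REDUCE_CASE branch in CustomAllreduce::allreduce, for illustration only.
enum class Algo { OneStage, TwoStage, Unsupported };

inline Algo pick_algo(int world_size, bool full_nvlink, size_t bytes) {
  if (world_size == 2) return Algo::OneStage;
  if (!full_nvlink) return Algo::Unsupported;   // the macro falls through without launching a kernel
  if ((world_size <= 4 && bytes < 512 * 1024) ||
      (world_size <= 8 && bytes < 256 * 1024))
    return Algo::OneStage;                      // latency-bound: each rank reads and sums all peers directly
  return Algo::TwoStage;                        // bandwidth-bound: reduce-scatter then all-gather
}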
sgl-kernel/csrc/allreduce/trt_reduce_internal.cu  (deleted, mode 100644 → 0, −532)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>
#include "trt_reduce_internal.cuh"
#include "utils.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
  asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
  uint32_t flag;
  asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}

static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
  uint32_t flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}

namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type Converter that packs data format to 128 bits data type
//
using PackedFloat = union {
  int4 packed;
  float unpacked[4];
};

using PackedHalf = union {
  int4 packed;
  half2 unpacked[4];
};

template <typename T>
struct PackedOn16Bytes {};

template <>
struct PackedOn16Bytes<float> {
  using Type = PackedFloat;
};

template <>
struct PackedOn16Bytes<half> {
  using Type = PackedHalf;
};
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
using PackedBFloat16 = union {
  int4 packed;
  __nv_bfloat162 unpacked[4];
};

template <>
struct PackedOn16Bytes<__nv_bfloat16> {
  using Type = PackedBFloat16;
};
#endif
// add two 128b data
template <typename T>
inline __device__ int4 add128b(T& a, T& b) {
  T c;
  c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
  c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
  c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
  c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
  return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx) {
  // After this function, at least one block in each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, world_size]
    // Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension
    // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
    size_t offset = (flag % 2) ? world_size : 0;

    if (bidx == 0) {
      st_flag_release(flag, signals[tidx] + offset + local_rank);
    }

    // All blocks check that corresponding block 0 on other GPUs have set the flag
    // No deadlock because block #0 is always the first block started
    uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
    while (ld_flag_acquire(peer_barrier_d) != flag) {
    }
  }

  __syncthreads();
}
template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx,
    int const grid_size) {
  if constexpr (!start) {
    __syncthreads();
  }
  // After this function, the block of id == bidx of each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, 2, num_blocks, world_size]
    // (+ an offset on dim 2 to account for flags used in multi_gpu_barrier)
    // Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension
    // Block broadcast its flag (local_rank on emitting dimension) to all receivers
    uint32_t flag_block_offset = world_size + bidx * world_size;
    flag_block_offset += (grid_size + 1) * world_size * (flag % 2);
    uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
    // Blocks check that corresponding blocks on other GPUs have also set the flag
    if constexpr (need_fence) {
      st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
      while (ld_flag_acquire(peer_barrier_d) != flag) {
      }
    } else {
      st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
      while (ld_flag_volatile(peer_barrier_d) != flag) {
      }
    }
  }

  if constexpr (start || need_fence) {
    __syncthreads();
  }
}
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// GPU 0 | B0 | B1 | B2 | B3 |
// GPU 1 | B0 | B1 | B2 | B3 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
// 3. B0 on GPU 0 pull and sum the chunk from GPU 1, writes the result to local_output
//
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
// We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 push the chunk is it responsible for into all other GPUs:
// params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
// 2. block sync so the block is shared by other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
  int const bidx = blockIdx.x;
  int const tidx = threadIdx.x;
  int const grid_size = gridDim.x;

  // The number of elements packed into one for comms
  static constexpr int NUM_ELTS = 16 / sizeof(T);

  // Packed data type for comms
  using PackedStruct = typename PackedOn16Bytes<T>::Type;

  // The source pointers. Distributed round-robin for the different warps.
  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
  // Start and end offsets of the thread
  size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
  size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);

  if constexpr (COPY_INPUT) {
    T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
    // Copy from local buffer to shareable buffer
    for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
      *reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
          *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
    }
  }
  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
    // Iterate over the different ranks/devices on the node to load the values.
    PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
    }

    // Sum the values from the different ranks.
    PackedStruct sums;
    sums.packed = {0, 0, 0, 0};
#pragma unroll
    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
      // Always reduce from rank 0 to ensure stable reduce order.
      sums.packed = add128b(sums, vals[rank]);
    }

    // Store to the destination buffer.
    *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
  }

  block_barrier<false>(
      params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
}
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// |--GPU 0--|--GPU 1--| (GPU responsibility parts)
// GPU 0 | B0 | B1 | B0 | B1 |
// GPU 1 | B0 | B1 | B0 | B1 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies all chunks is it responsible for, from local_input to shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
// 3. B0 on GPU 0 gather and sum the B0 chunks from GPU 1, that are in the GPU 0 responsibility
// part (the first half of the message, see GPU responsibility row above)
// 3bis. Likewise, B0 on GPU 1 copies and sum the chunks for GPU 0,
// where GPU 1 is responsible: the second half of the message.
// 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
// 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU.
// For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1.
//
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
// We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready
// to be read.
//
// Note that compared to one-shot, one block (CTA) writes multiple input chunks and write multiple output chunks.
// However, it's only responsible for the summation of a single chunk.
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 push the chunks is it responsible for into the corresponding GPUs:
// params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
// 2. block sync so the blocks have been shared by other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
// 4. block barrier (corresponding blocks have finished reduction)
// 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is
// written at index 0 of 2nd dim)
  int const bidx = blockIdx.x;
  int const tidx = threadIdx.x;
  int const grid_size = gridDim.x;

  // The number of elements packed into one for comms
  static constexpr int PACKED_ELTS = 16 / sizeof(T);
  using PackedType = typename PackedOn16Bytes<T>::Type;

  T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
  T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);

  size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
  size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);

  T* buffers[RANKS_PER_NODE];
  T* buffers_unorder[RANKS_PER_NODE];
  int ranks[RANKS_PER_NODE];
#pragma unroll
  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
    // A mapping of the ranks to scatter reads as much as possible
    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
    ranks[ii] = rank;
    buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
    buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
  }

#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
#endif
#endif

  if constexpr (COPY_INPUT) {
    // Copy all blocks from local buffer to shareable buffer
    for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
      for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
        size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
        if (offset_rank >= params.elts_total) {
          continue;
        }
        *reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
            *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
      }
    }
  }
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
    size_t const responsible_block_offset = local_offset + params.rank_offset;

    // Iterate over the different ranks/devices on the node to load the values.
    PackedType vals[RANKS_PER_NODE];
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
    }

    // Sum the values from the different ranks.
    PackedType sums;
    sums.packed = {0, 0, 0, 0};
#pragma unroll
    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
      // Always reduce from rank 0 to ensure stable reduce order.
      sums.packed = add128b(sums, vals[rank]);
    }

    // Store to the local buffer or tmp buffer
    if constexpr (COPY_INPUT) {
      *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
    } else {
      *reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
    }
  }

  block_barrier<false, true>(
      params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Gather all needed elts from other intra-node ranks
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      // use round-robin gathering from other ranks
      size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
      if (offset_rank >= params.elts_total) {
        continue;
      }
      if constexpr (COPY_INPUT) {
        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
            *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
      } else {
        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
            *reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
      }
    }
  }

#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaTriggerProgrammaticLaunchCompletion();
#endif
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int divUp(int a, int b) {
  return (a + b - 1) / b;
}

inline int roundUp(int a, int n) {
  return divUp(a, n) * n;
}
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
  int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      assert(params.elts_total % elts_per_thread == 0);
      size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
      blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
      params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
      params.elts_per_rank = params.elts_total;
      break;
    }
    case AllReduceStrategyType::TWOSHOT: {
      assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
      size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
      blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
      params.elts_per_rank = params.elts_total / params.ranks_per_node;
      params.rank_offset = params.local_rank * params.elts_per_rank;
      params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
      break;
    }
    default:
      assert(false && "Algorithm not supported here.");
  }

  return std::make_tuple(blocks_per_grid, threads_per_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(
    AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block, cudaStream_t stream) {
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
      break;
    }
    case AllReduceStrategyType::TWOSHOT: {
      twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
      break;
    }
  }
}

template <typename T, bool COPY_INPUT>
void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
  size_t elts_per_thread = 16 / sizeof(T);
  auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
  switch (param.ranks_per_node) {
    case 2:
      dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 4:
      dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 6:
      dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 8:
      dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    default:
      break;
  }
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
  if (param.is_capturing) {
    dispatchARKernelsCopyInput<T, false>(strat, param, stream);
  } else {
    dispatchARKernelsCopyInput<T, true>(strat, param, stream);
  }
  CHECK_CUDA_SUCCESS(cudaGetLastError());
}

void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
  if (params.elts_total == 0) {
    return;
  }

  switch (data_type) {
    case at::ScalarType::Float:
      invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
      break;
    case at::ScalarType::Half:
      invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
      break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
      invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
      break;
#endif
    default:
      assert(false && "Unsupported data type");
  }
}
}  // namespace trt_llm
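To make the launch-config arithmetic in kernelLaunchConfig above concrete, the standalone sketch below redoes the ONESHOT computation for an example fp16 message. The TRT-LLM constants are declared locally as stand-in values and are assumptions here; their real definitions live in the deleted trt_reduce_internal.cuh, which is only partially shown at the bottom of this page.

#include <algorithm>
#include <cstdio>

int main() {
  // Stand-in values only; the actual constants come from trt_reduce_internal.cuh.
  const int DEFAULT_BLOCK_SIZE = 512;
  const int WARP_SIZE = 32;
  const int MAX_ALL_REDUCE_BLOCKS = 24;

  const int elts_total = 1 << 20;      // e.g. 1M half-precision elements
  const int elts_per_thread = 16 / 2;  // 16 bytes per thread / sizeof(half) = 8

  // ONESHOT branch: round the thread count up to a warp, cap the block and grid sizes,
  // then give every block a whole number of per-thread chunks.
  int total_threads = ((elts_total / elts_per_thread + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;  // roundUp
  int threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
  int blocks_per_grid =
      std::min(MAX_ALL_REDUCE_BLOCKS, (total_threads + threads_per_block - 1) / threads_per_block);  // divUp
  int elts_per_block =
      (((elts_total + blocks_per_grid - 1) / blocks_per_grid + elts_per_thread - 1) / elts_per_thread) *
      elts_per_thread;  // roundUp(divUp(elts_total, blocks), elts_per_thread)

  std::printf("threads/block=%d blocks=%d elts/block=%d\n", threads_per_block, blocks_per_grid, elts_per_block);
  return 0;
}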
sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu  (deleted, mode 100644 → 0, −226)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include "trt_reduce_internal.cuh"
#include "utils.h"
using namespace trt_llm;

using fptr_t = int64_t;
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;

class AllReduceMeta {
 public:
  AllReduceMeta(
      int64_t rank_id,
      int64_t world_size,
      torch::Tensor& rank_data,
      const std::vector<fptr_t>& buffers,
      const std::vector<fptr_t>& tmp_result_buffers,
      const std::vector<fptr_t>& barrier_in,
      const std::vector<fptr_t>& barrier_out) {
    this->rank_id = (int)rank_id;
    this->world_size = (int)world_size;
    this->barrier_in = barrier_in;
    this->barrier_out = barrier_out;
    this->tmp_result_buffers = tmp_result_buffers;
    this->rank_data_base = reinterpret_cast<RankData*>(rank_data.data_ptr());
    RankData data;
    for (int i = 0; i < world_size; i++) {
      data.ptrs[i] = (void*)buffers[i];
    }
    auto d_data = this->rank_data_base++;
    CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    this->buffers = d_data;
  }
  ~AllReduceMeta() {
    for (auto [_, ptr] : ipc_handles_) {
      CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
    }
  }

 public:
  int world_size;
  int rank_id;
  std::vector<fptr_t> barrier_in;
  std::vector<fptr_t> barrier_out;
  std::vector<fptr_t> tmp_result_buffers;
  int barrier_flag = 1;
  RankData* buffers;
  RankData* rank_data_base;
  std::vector<void*> graph_unreg_buffers;
  std::map<IPC_KEY, char*> ipc_handles_;
};
// Get the number of bits for a given data type.
inline int get_bits(at::ScalarType dtype) {
  switch (dtype) {
    case at::ScalarType::Float:
      return 32;
    case at::ScalarType::Half:
    case at::ScalarType::BFloat16:
      return 16;
    default:
      assert(false && "Unsupported data type");
  }
}
// Check if customized all-reduce kernels can be applied.
inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
  // The customized all-reduce kernel has the following requirement(s).
  return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
fptr_t init_custom_ar(
    int64_t rank_id,
    int64_t world_size,
    torch::Tensor& rank_data,
    const std::vector<fptr_t>& buffers,
    const std::vector<fptr_t>& tmp_result_buffers,
    const std::vector<fptr_t>& barrier_in,
    const std::vector<fptr_t>& barrier_out) {
  auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
  return (fptr_t)m;
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
  delete fa;
}
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  auto num_buffers = m->graph_unreg_buffers.size();
  auto handle_sz = sizeof(cudaIpcMemHandle_t);
  std::string handles(handle_sz * num_buffers, static_cast<char>(0));
  std::vector<int64_t> offsets(num_buffers);
  for (int i = 0; i < num_buffers; i++) {
    auto ptr = m->graph_unreg_buffers[i];
    void* base_ptr;
    // note: must share the base address of each allocation, or we get wrong
    // address
    if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) {
      assert(false && "failed to get pointer attr");
    }
    CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
    offsets[i] = ((char*)ptr) - ((char*)base_ptr);
  }
  std::vector<int64_t> bytes(handles.begin(), handles.end());
  return std::make_pair(bytes, offsets);
}
char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
  auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
  if (new_handle) {
    char* ipc_ptr;
    CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
        (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
    it->second = ipc_ptr;
  }
  return it->second;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  std::vector<std::string> handle_bytes;
  handle_bytes.reserve(handles.size());
  for (int i = 0; i < handles.size(); i++) {
    handle_bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  auto num_buffers = m->graph_unreg_buffers.size();
  std::vector<RankData> rank_data(num_buffers);
  for (int i = 0; i < num_buffers; i++) {
    auto self_ptr = m->graph_unreg_buffers[i];
    auto& rd = rank_data[i];
    for (int j = 0; j < m->world_size; j++) {
      if (j != m->rank_id) {
        char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]);
        handle += offsets[j][i];
        rd.ptrs[j] = handle;
      } else {
        rd.ptrs[j] = self_ptr;
      }
    }
  }
  CHECK_CUDA_SUCCESS(
      cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
  m->rank_data_base += num_buffers;
  m->graph_unreg_buffers.clear();
}
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  auto num_elements = inp.numel();
  auto dtype = inp.scalar_type();
  AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);

  // should be gurantee in python code
  assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT);
  assert(CanApplyCustomAllReduce(num_elements, dtype));

  // Initialize the all-reduce kernel arguments.
  int world_size = m->world_size;

  AllReduceParams params;
  params.ranks_per_node = world_size;
  params.rank = m->rank_id;
  params.local_rank = m->rank_id;
  params.local_input_buffer_ptr = inp.data_ptr();
  params.local_output_buffer_ptr = out.data_ptr();
  params.elts_total = inp.numel();
  params.elts_size = inp.element_size();
  params.barrier_flag = ++(m->barrier_flag);

  cudaStreamCaptureStatus status;
  CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
  params.is_capturing = (status == cudaStreamCaptureStatusActive);
  if (params.is_capturing) {
    params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size();
    m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr);
  } else {
    params.peer_comm_buffer_ptrs = m->buffers;
  }

  for (int i = 0; i < world_size; ++i) {
    params.tmp_result_buffers[i] = reinterpret_cast<uint32_t*>(m->tmp_result_buffers[i]);
  }
  for (int i = 0; i < world_size; ++i) {
    params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
  }
  for (int i = 0; i < world_size; ++i) {
    params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
  }

  auto data_type = out.scalar_type();
  trtCustomAllReduce(params, data_type, strategy, stream);
}
sgl-kernel/csrc/common_extension.cc  (+7 −4)

@@ -26,15 +26,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
   m.def("register_graph_buffers", &register_graph_buffers);
   m.def("dispose", &dispose);
+  m.def("meta_size", &meta_size);
+  m.def("register_buffer", &register_buffer);
   m.def(
-      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
-      "barrier_in, int[] barrier_out) -> int");
+      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
+      "int rank, bool full_nvlink) -> int");
   m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
+  m.def(
+      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
+      "int reg_buffer_sz_bytes) -> ()");
   m.impl("all_reduce", torch::kCUDA, &all_reduce);

   /*
    * From csrc/attention
    */
...
sgl-kernel/include/sgl_kernel_ops.h  (+7 −10)

@@ -21,6 +21,7 @@ limitations under the License.
 #include <torch/library.h>
 #include <torch/torch.h>
+#include <tuple>
 #include <vector>

 #define _CONCAT(A, B) A##B
...

@@ -63,18 +64,14 @@ void register_graph_buffers(
 torch::Tensor allocate_meta_buffer(int64_t size);
 torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
 #else
-// TRTLLM custom allreduce
-fptr_t init_custom_ar(
-    int64_t rank_id,
-    int64_t world_size,
-    torch::Tensor& rank_data,
-    const std::vector<fptr_t>& buffers,
-    const std::vector<fptr_t>& tmp_result_buffers,
-    const std::vector<fptr_t>& barrier_in,
-    const std::vector<fptr_t>& barrier_out);
+// custom allreduce
+fptr_t
+init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
 void dispose(fptr_t _fa);
-void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+int64_t meta_size();
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);
 std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
 void register_graph_buffers(
     fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
 #endif
...
sgl-kernel/include/trt_reduce_internal.cuh
deleted
100644 → 0
View file @
0d99adb7
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>
namespace trt_llm {

constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 32;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;

enum class AllReduceStrategyType : int8_t {
  RING = 0,
  ONESHOT = 1,
  TWOSHOT = 2,
  AUTO = 3,
};

struct RankData {
  void* ptrs[MAX_RANKS_PER_NODE];
};

struct AllReduceParams {
  size_t elts_size;
  size_t elts_total;
  size_t elts_per_rank;
  size_t elts_per_block;
  size_t rank_offset;
  size_t ranks_per_node, rank, local_rank;
  uint32_t barrier_flag;
  uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
  uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
  uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
  RankData* peer_comm_buffer_ptrs;
  void* local_input_buffer_ptr;
  void* local_output_buffer_ptr;
  bool is_capturing;
};

inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
  if (world_size <= 2) {
    return 16 * 1024 * 1024;
  }
  return 8 * 1024 * 1024;
}

inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
  const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);

  if (message_size > maxWorkspaceSize) {
    assert(false && "Custom allreduce does not support the ring strategy currently");
    return AllReduceStrategyType::RING;
  }

  if (world_size <= 2) {
    return AllReduceStrategyType::ONESHOT;
  }

  if (world_size <= 4) {
    if (message_size < 1 * 1024 * 1024) {
      return AllReduceStrategyType::ONESHOT;
    }
    return AllReduceStrategyType::TWOSHOT;
  }

  if (message_size < 512 * 1024) {
    return AllReduceStrategyType::ONESHOT;
  }
  return AllReduceStrategyType::TWOSHOT;
}

void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);

}  // namespace trt_llm
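The dispatch heuristic this deleted header implemented can be read off SelectImplementation above: messages larger than the per-rank workspace would need the unsupported RING path, small messages take ONESHOT, and larger ones take TWOSHOT, with the cutoff depending on world size. A hedged Python restatement of those thresholds (sizes in bytes, purely illustrative; the replacement vllm-style kernel uses its own dispatch):

def select_strategy(message_size: int, world_size: int) -> str:
    # Illustrative restatement of the deleted SelectImplementation heuristic.
    max_workspace = 16 * 1024 * 1024 if world_size <= 2 else 8 * 1024 * 1024
    if message_size > max_workspace:
        return "RING"  # was asserted unreachable: ring is not supported
    if world_size <= 2:
        return "ONESHOT"
    if world_size <= 4:
        return "ONESHOT" if message_size < 1 * 1024 * 1024 else "TWOSHOT"
    return "ONESHOT" if message_size < 512 * 1024 else "TWOSHOT"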
sgl-kernel/python/sgl_kernel/allreduce.py
View file @
bcbbf519
...
@@ -50,28 +50,38 @@ if torch.version.hip is not None:
         return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

 else:
-    # TRTLLM custom allreduce
-    def init_custom_reduce(
-        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
-    ):
+    def init_custom_ar(
+        ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
+    ) -> int:
         return torch.ops.sgl_kernel.init_custom_ar.default(
-            rank_id,
-            num_devices,
-            rank_data,
-            buffers,
-            tmp_buffers,
-            barrier_in,
-            barrier_out,
+            ipc_tensors, rank_data, rank, full_nvlink
         )

-    def custom_dispose(fa):
+    def dispose(fa: int) -> None:
         torch.ops.sgl_kernel.dispose.default(fa)

-    def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)
+    def all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        reg_buffer: int,
+        reg_buffer_sz_bytes: int,
+    ) -> None:
+        torch.ops.sgl_kernel.all_reduce.default(
+            fa, inp, out, reg_buffer, reg_buffer_sz_bytes
+        )

-    def get_graph_buffer_ipc_meta(fa):
+    def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
         return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

-    def register_graph_buffers(fa, handles, offsets):
+    def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
+        return torch.ops.sgl_kernel.register_buffer.default(fa, fake_ipc_ptrs)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
         torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
+
+    def meta_size() -> int:
+        return torch.ops.sgl_kernel.meta_size.default()
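In short, the Python surface now mirrors the vllm-style schemas registered in common_extension.cc: init_custom_ar, register_buffer, all_reduce (with an explicit registered staging buffer), get_graph_buffer_ipc_meta, register_graph_buffers, dispose, and meta_size replace init_custom_reduce / custom_reduce / custom_dispose. A hedged migration sketch of one rank's call order against these wrappers, assuming meta_ptrs and buffer_ptrs are IPC-shared device pointers created elsewhere (the test below builds them with create_shared_buffer) and the module is imported as custom_ops:

import torch
from sgl_kernel import allreduce as custom_ops  # assumed import path


def run_once(meta_ptrs, buffer_ptrs, rank, inp, max_size=8192 * 1024):
    # Hedged sketch; buffer allocation and IPC exchange are elided.
    rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=inp.device)

    # previously: custom_ops.init_custom_reduce(rank, world_size, rank_data, ...)
    fa = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
    custom_ops.register_buffer(fa, buffer_ptrs)

    out = torch.empty_like(inp)
    # previously: custom_ops.custom_reduce(fa, inp, out)
    custom_ops.all_reduce(fa, inp, out, buffer_ptrs[rank], max_size)

    # previously: custom_ops.custom_dispose(fa)
    custom_ops.dispose(fa)
    return out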
sgl-kernel/tests/test_trt_allreduce.py → sgl-kernel/tests/test_custom_allreduce.py
View file @
bcbbf519
...
@@ -16,7 +16,6 @@ from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibra
 def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    ranks = list(range(world_size))
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     dist.init_process_group(
         backend="nccl",
...
@@ -26,39 +25,18 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
     )
     group = dist.group.WORLD
-    buffer_max_size = 8 * 1024 * 1024
-    barrier_max_size = 8 * (24 + 2) * 8
-    buffer_ptrs = None
-    tmp_result_buffer_ptrs = None
-    barrier_in_ptrs = None
-    barrier_out_ptrs = None
-    custom_ptr = None
     try:
-        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
-            buffer_max_size, group=group
-        )
-        tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
-            buffer_max_size, group=group
-        )
-        barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(barrier_max_size, group=group)
-        barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(barrier_max_size, group=group)
+        device = torch.device(f"cuda:{rank}")
+        max_size = 8192 * 1024
+        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
+            custom_ops.meta_size() + max_size, group=group
+        )
         rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
-        custom_ptr = custom_ops.init_custom_reduce(
-            rank,
-            world_size,
-            rank_data,
-            buffer_ptrs,
-            tmp_result_buffer_ptrs,
-            barrier_in_ptrs,
-            barrier_out_ptrs,
-        )
+        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)
+        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
+        custom_ops.register_buffer(custom_ptr, buffer_ptrs)
         test_loop = 10
         for sz in test_sizes:
...
@@ -68,7 +46,9 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
                 inp1_ref = inp1.clone()
                 out1 = torch.empty_like(inp1)
-                custom_ops.custom_reduce(custom_ptr, inp1, out1)
+                custom_ops.all_reduce(
+                    custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
+                )
                 dist.all_reduce(inp1_ref, group=group)
...
@@ -77,15 +57,11 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
     finally:
         dist.barrier(group=group)
         if custom_ptr is not None:
-            custom_ops.custom_dispose(custom_ptr)
+            custom_ops.dispose(custom_ptr)
         if buffer_ptrs:
             TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
-        if tmp_result_buffer_ptrs:
-            TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
-        if barrier_in_ptrs:
-            TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
-        if barrier_out_ptrs:
-            TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
+        if meta_ptrs:
+            TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)
         dist.destroy_process_group(group=group)
...
@@ -122,7 +98,18 @@ def multi_process_parallel(
 class TestCustomAllReduce(unittest.TestCase):
-    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+    test_sizes = [
+        512,
+        2560,
+        4096,
+        5120,
+        7680,
+        32768,
+        262144,
+        524288,
+        1048576,
+        2097152,
+    ]
     world_sizes = [2, 4, 8]

     @staticmethod
...
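The verification pattern itself carries over from the old test: reduce the same integer-valued tensor once with the custom kernel and once with NCCL, then compare the results. A hedged sketch of the per-size check inside _run_correctness_worker; the dtype list, the randint-based inputs, and the helper's name and signature are assumptions, not a copy of the file:

import torch
import torch.distributed as dist
from sgl_kernel import allreduce as custom_ops  # assumed import path


def check_sizes(custom_ptr, buffer_ptrs, rank, max_size, group, device, test_sizes, test_loop=10):
    for sz in test_sizes:
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            for _ in range(test_loop):
                # Small integers keep the reduced sum exactly representable in low precision.
                inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                inp1_ref = inp1.clone()
                out1 = torch.empty_like(inp1)

                custom_ops.all_reduce(custom_ptr, inp1, out1, buffer_ptrs[rank], max_size)
                dist.all_reduce(inp1_ref, group=group)

                torch.testing.assert_close(out1, inp1_ref)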