change / sglang / Commits / 3900a94a

Unverified commit 3900a94a, authored Jan 06, 2025 by yizhang2077, committed by GitHub on Jan 06, 2025
Support twoshot kernel (#2688)
parent ded9fcd0

Showing 5 changed files with 216 additions and 21 deletions (+216, -21)
sgl-kernel/pyproject.toml                                +1    -1
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu    +204  -2
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh   +6    -8
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu      +1    -1
sgl-kernel/tests/test_trt_reduce.py                      +4    -9
sgl-kernel/pyproject.toml (view file @ 3900a94a)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post10"
+version = "0.0.2.post11"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu (view file @ 3900a94a)

@@ -41,6 +41,16 @@ static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
   return flag;
 }
 
+static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
+  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+}
+
+static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
+  uint32_t flag;
+  asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+  return flag;
+}
+
 namespace trt_llm {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -116,6 +126,45 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
   __syncthreads();
 }
 
+__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
+                                         size_t const world_size, int const tidx, int const bidx, int const grid_size,
+                                         bool start = true, bool need_fence = false) {
+  if (!start) {
+    __syncthreads();
+  }
+  // After this function, the block of id == bidx of each GPU has reached the barrier
+  if (tidx < world_size) {
+    // we can think of signals having the shape [world_size, 2, num_blocks, world_size]
+    // (+ an offset on dim 2 to account for flags used in multi_gpu_barrier)
+    // Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension
+    // Block broadcast its flag (local_rank on emitting dimension) to all receivers
+    uint32_t flag_block_offset = world_size + bidx * world_size;
+    if (flag % 2 == 1) {
+      flag_block_offset += (grid_size + 1) * world_size;
+    }
+    if (need_fence) {
+      st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
+    } else {
+      st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
+    }
+    // Blocks check that corresponding blocks on other GPUs have also set the flag
+    uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
+    if (need_fence) {
+      while (ld_flag_acquire(peer_barrier_d) != flag) {
+      }
+    } else {
+      while (ld_flag_volatile(peer_barrier_d) != flag) {
+      }
+    }
+  }
+
+  __syncthreads();
+}
+
 template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
 static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
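The barrier above computes a flag slot from the block index, the barrier-flag parity, and the emitting rank. As a reading aid, here is a small Python sketch of that offset arithmetic; the helper name and the layout summary in the comments are paraphrases of the code above, not part of the commit.

# Sketch of the flag-slot arithmetic used by block_barrier (illustrative only).
# The signal buffer is viewed roughly as: world_size slots reserved for multi_gpu_barrier,
# followed by two parity banks of (grid_size + 1) * world_size slots, indexed by (block, emitting rank).

def flag_slot(flag: int, bidx: int, emitting_rank: int, world_size: int, grid_size: int) -> int:
    offset = world_size + bidx * world_size       # skip multi_gpu_barrier flags, pick this block's row
    if flag % 2 == 1:                             # odd flags use the second bank so slots are not reused too early
        offset += (grid_size + 1) * world_size
    return offset + emitting_rank

if __name__ == "__main__":
    # Example: 8 GPUs, 36 blocks, even flag, block 3, rank 5 emitting.
    print(flag_slot(flag=2, bidx=3, emitting_rank=5, world_size=8, grid_size=36))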
@@ -189,6 +238,124 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   }
 }
 
+template <typename T, int RANKS_PER_NODE>
+static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
+  // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
+  // The message is partitioned into chunks as detailed below:
+  //               message
+  //       |-------------------|
+  //       |--GPU 0--|--GPU 1--| (GPU responsibility parts)
+  //  GPU 0 | B0 | B1 | B0 | B1 |
+  //  GPU 1 | B0 | B1 | B0 | B1 |
+  //
+  // Here the step-by-step behavior of one block:
+  // 1. B0 copies all chunks is it responsible for, from local_input to shareable buffer
+  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
+  // 3. B0 on GPU 0 gather and sum the B0 chunks from GPU 1, that are in the GPU 0 responsibility
+  //    part (the first half of the message, see GPU responsibility row above)
+  // 3bis. Likewise, B0 on GPU 1 copies and sum the chunks for GPU 0,
+  //       where GPU 1 is responsible: the second half of the message.
+  // 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
+  // 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU.
+  //    For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1.
+  //
+  // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
+  // We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready
+  // to be read.
+  //
+  // Note that compared to one-shot, one block (CTA) writes multiple input chunks and write multiple output chunks.
+  // However, it's only responsible for the summation of a single chunk.
+  //
+  // With PUSH_MODE, we consider that the shared buffer is of size:
+  // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
+  //
+  // Here the step-by-step behavior of one block:
+  // 1. B0 push the chunks is it responsible for into the corresponding GPUs:
+  //    params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
+  // 2. block sync so the blocks have been shared by other GPUs
+  // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
+  // 4. block barrier (corresponding blocks have finished reduction)
+  // 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is
+  //    written at index 0 of 2nd dim)
+
+  int const bidx = blockIdx.x;
+  int const tidx = threadIdx.x;
+  int const grid_size = gridDim.x;
+
+  // The number of elements packed into one for comms
+  static constexpr int PACKED_ELTS = 16 / sizeof(T);
+  using PackedType = typename PackedOn16Bytes<T>::Type;
+
+  T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
+  T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
+
+  size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
+  size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
+
+  T* buffers[RANKS_PER_NODE];
+  int ranks[RANKS_PER_NODE];
+#pragma unroll
+  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+    // A mapping of the ranks to scatter reads as much as possible
+    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
+    ranks[ii] = rank;
+    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
+
+  block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                grid_size);
+
+  // Each block accumulates the values from the different GPUs on the same node.
+  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+    size_t const responsible_block_offset = local_offset + params.rank_offset;
+
+    // Iterate over the different ranks/devices on the node to load the values.
+    PackedType vals[RANKS_PER_NODE];
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
+    }
+
+    // Sum the values from the different ranks.
+    PackedType sums;
+    sums.packed = {0, 0, 0, 0};
+#pragma unroll
+    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
+      // Always reduce from rank 0 to ensure stable reduce order.
+      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
+      sums.packed = add128b(sums, vals[ii]);
+    }
+
+    // Store to the local buffer.
+    *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+  }
+
+  block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                grid_size, false, true);
+
+  // Gather all needed elts from other intra-node ranks
+  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      // use round-robin gathering from other ranks
+      size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
+      if (offset_rank >= params.elts_total) {
+        continue;
+      }
+      *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) = *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+    }
+  }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 inline int divUp(int a, int b) {
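The long comment in the kernel describes a reduce-scatter over per-rank responsibility chunks followed by an all-gather. As a sanity check of that data flow, the following host-side Python model reproduces only the chunk bookkeeping; it is an illustrative sketch and does not model thread blocks, 16-byte packing, or the barriers.

# Host-side model of the two-shot all-reduce data flow (illustrative, not the CUDA kernel).
from typing import List

def two_shot_allreduce(inputs: List[List[float]]) -> List[List[float]]:
    world_size = len(inputs)
    n = len(inputs[0])
    assert n % world_size == 0, "message must split evenly across ranks"
    chunk = n // world_size  # elts_per_rank

    # Phase 1 (reduce-scatter): rank r sums everyone's r-th chunk.
    partial = [
        [sum(inputs[src][r * chunk + i] for src in range(world_size)) for i in range(chunk)]
        for r in range(world_size)
    ]

    # Phase 2 (all-gather): every rank pulls each reduced chunk from its owner rank.
    result = [x for r in range(world_size) for x in partial[r]]
    return [list(result) for _ in range(world_size)]

if __name__ == "__main__":
    ins = [[1.0] * 8, [2.0] * 8]  # two ranks, 8 elements each
    outs = two_shot_allreduce(ins)
    assert all(out == [3.0] * 8 for out in outs)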
@@ -211,6 +378,33 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
       params.elts_per_rank = params.elts_total;
       break;
     }
+    case AllReduceStrategyType::TWOSHOT: {
+      assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
+      size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
+
+      /*
+      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
+      blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
+      */
+      while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
+        blocks_per_grid += 1;
+      }
+
+      threads_per_block = total_threads / blocks_per_grid;
+
+      // NOTE: need to adjust here
+      if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
+        size_t iter_factor = 1;
+        while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
+          iter_factor += 1;
+        }
+        blocks_per_grid /= iter_factor;
+      }
+      params.elts_per_rank = params.elts_total / params.ranks_per_node;
+      params.rank_offset = params.local_rank * params.elts_per_rank;
+      params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
+      break;
+    }
     default:
       assert(false && "Algorithm not supported here.");
   }
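The TWOSHOT case above grows blocks_per_grid until it divides the total thread count with at most DEFAULT_BLOCK_SIZE threads per block, then factors it down to fit under MAX_ALL_REDUCE_BLOCKS. A Python transcription of that arithmetic for quick experiments follows; the starting value of blocks_per_grid is set outside this hunk, so the sketch assumes it starts at 1.

# Mirror of the TWOSHOT block/grid sizing logic (illustrative sketch, not the C++ code).
WARP_SIZE = 32
DEFAULT_BLOCK_SIZE = 512
MAX_ALL_REDUCE_BLOCKS = 36

def round_up(x: int, m: int) -> int:
    return (x + m - 1) // m * m

def div_up(a: int, b: int) -> int:
    return (a + b - 1) // b

def twoshot_launch_config(elts_total: int, elts_per_thread: int, ranks_per_node: int,
                          blocks_per_grid: int = 1):
    assert elts_total % (elts_per_thread * ranks_per_node) == 0
    total_threads = round_up(elts_total // (elts_per_thread * ranks_per_node), WARP_SIZE)

    # Grow the grid until it divides total_threads with a block size <= DEFAULT_BLOCK_SIZE.
    while total_threads % blocks_per_grid != 0 or total_threads // blocks_per_grid > DEFAULT_BLOCK_SIZE:
        blocks_per_grid += 1
    threads_per_block = total_threads // blocks_per_grid

    # If the grid is still too large, divide it by the smallest factor that fits.
    if blocks_per_grid > MAX_ALL_REDUCE_BLOCKS:
        iter_factor = 1
        while blocks_per_grid // iter_factor > MAX_ALL_REDUCE_BLOCKS or blocks_per_grid % iter_factor:
            iter_factor += 1
        blocks_per_grid //= iter_factor

    elts_per_rank = elts_total // ranks_per_node
    elts_per_block = round_up(div_up(elts_per_rank, blocks_per_grid), elts_per_thread)
    return blocks_per_grid, threads_per_block, elts_per_rank, elts_per_block

if __name__ == "__main__":
    # fp16: 16 / sizeof(half) = 8 elements per thread, 8 ranks, 1 Mi elements total.
    print(twoshot_launch_config(elts_total=1 << 20, elts_per_thread=8, ranks_per_node=8))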
@@ -223,7 +417,16 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
 template <typename T, int RANKS_PER_NODE>
 void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
                        cudaStream_t stream) {
-  oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+  switch (algo) {
+    case AllReduceStrategyType::ONESHOT: {
+      oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      break;
+    }
+    case AllReduceStrategyType::TWOSHOT: {
+      twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      break;
+    }
+  }
 }
 
 template <typename T>
@@ -233,7 +436,6 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
   CHECK_CUDA_SUCCESS(
       cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
-  assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot");
   CHECK_CUDA_SUCCESS(cudaGetLastError());
 
   size_t elts_per_thread = 16 / sizeof(T);
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh (view file @ 3900a94a)

@@ -25,9 +25,9 @@
 namespace trt_llm {
 
 constexpr size_t WARP_SIZE = 32;
-constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
+constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
 constexpr size_t MAX_RANKS_PER_NODE = 8;
-constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
+constexpr size_t DEFAULT_BLOCK_SIZE = 512;
 
 enum class AllReduceStrategyType : int8_t {
   RING = 0,
@@ -53,9 +53,9 @@ struct AllReduceParams {
 inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
   if (world_size <= 2) {
-    return 16 * 1000 * 1000;
+    return 16 * 1024 * 1024;
   }
-  return 8 * 1000 * 1000;
+  return 8 * 1024 * 1024;
 }
 
 inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
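The workspace bound now uses binary megabytes (1024 * 1024) instead of decimal millions of bytes. A one-line Python equivalent, shown only as a sketch of the same bound:

# Workspace upper bound mirrored from GetMaxRequiredWorkspaceSize (illustrative sketch).
def max_required_workspace_size(world_size: int) -> int:
    return (16 if world_size <= 2 else 8) * 1024 * 1024  # bytes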
@@ -71,17 +71,15 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world
   }
 
   if (world_size <= 4) {
-    if (message_size < 1 * 1000 * 1000) {
+    if (message_size < 1 * 1024 * 1024) {
       return AllReduceStrategyType::ONESHOT;
     }
-    assert(false && "Custom allreduce do not twoshot currently");
     return AllReduceStrategyType::TWOSHOT;
   }
 
-  if (message_size < 500 * 1000) {
+  if (message_size < 512 * 1024) {
     return AllReduceStrategyType::ONESHOT;
   }
-  assert(false && "Custom allreduce do not twoshot currently");
   return AllReduceStrategyType::TWOSHOT;
 }
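With the two-shot path enabled, strategy selection reduces to the message-size thresholds above (now expressed in MiB/KiB). A Python sketch of the same decision, with a hypothetical Strategy enum standing in for AllReduceStrategyType; branches outside the shown hunk (for example world_size <= 2) are not reproduced:

# Strategy selection mirrored from SelectImplementation (illustrative sketch).
from enum import Enum

class Strategy(Enum):
    ONESHOT = 1
    TWOSHOT = 2

def select_implementation(message_size: int, world_size: int) -> Strategy:
    if world_size <= 4:
        return Strategy.ONESHOT if message_size < 1 * 1024 * 1024 else Strategy.TWOSHOT
    return Strategy.ONESHOT if message_size < 512 * 1024 else Strategy.TWOSHOT

# Example: an 8-GPU node with a 2 MiB message now takes the two-shot path.
assert select_implementation(2 * 1024 * 1024, 8) is Strategy.TWOSHOT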
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu (view file @ 3900a94a)

@@ -71,7 +71,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
   AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);
 
   // should be gurantee in python code
-  assert(strategy == AllReduceStrategyType::ONESHOT);
+  assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT);
   assert(CanApplyCustomAllReduce(num_elements, dtype));
 
   // Initialize the all-reduce kernel arguments.
sgl-kernel/tests/test_trt_reduce.py (view file @ 3900a94a)

@@ -55,13 +55,8 @@ class TestCustomAllReduce(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
        random.seed(42)
-        cls.test_sizes = {
-            2: [512, 4096, 32768, 262144, 2097152],
-            4: [512, 4096, 32768, 131072],
-            6: [512, 4096, 32768, 65536],
-            8: [512, 4096, 32768, 65536],
-        }
-        cls.world_sizes = [2, 4, 6, 8]
+        cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+        cls.world_sizes = [2, 4, 8]
 
     @staticmethod
     def create_shared_buffer(
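The flattened size list interacts with the two-shot kernel's divisibility assert in kernelLaunchConfig (elts_total % (elts_per_thread * ranks_per_node) == 0, with elts_per_thread = 16 / sizeof(T)). A quick Python check, assuming the usual 4/2/2-byte element sizes for the tested dtypes, confirms every listed size satisfies it for world sizes 2, 4, and 8:

# Check that the test sizes divide evenly for the two-shot kernel's assert (illustrative sketch).
TEST_SIZES = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
WORLD_SIZES = [2, 4, 8]
DTYPE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2}

for name, nbytes in DTYPE_BYTES.items():
    elts_per_thread = 16 // nbytes  # elements packed into one 16-byte access
    for ws in WORLD_SIZES:
        for sz in TEST_SIZES:
            assert sz % (elts_per_thread * ws) == 0, (name, ws, sz)
print("all test sizes satisfy the two-shot divisibility constraint")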
@@ -194,7 +189,7 @@ class TestCustomAllReduce(unittest.TestCase):
         self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
 
         test_loop = 10
-        for sz in self.test_sizes[world_size]:
+        for sz in self.test_sizes:
             for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                 for _ in range(test_loop):
                     inp1 = torch.randint(
@@ -216,7 +211,7 @@ class TestCustomAllReduce(unittest.TestCase):
         self.init_vllm_allreduce(rank, group)
         self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
 
-        for sz in self.test_sizes[world_size]:
+        for sz in self.test_sizes:
             inp1 = torch.randint(
                 1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
             )
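For reference, the check these test loops perform is that every rank's all-reduce output equals the elementwise sum of all ranks' inputs. A standalone Python sketch of that oracle, with a placeholder standing in for the custom kernel's output:

# Reference oracle for an all-reduce(sum) test (illustrative; the real test runs one process per rank).
import torch

world_size = 4
sz = 4096
# One random input per simulated rank, mirroring the test's torch.randint(1, 16, (sz,)) inputs.
inputs = [torch.randint(1, 16, (sz,), dtype=torch.float32) for _ in range(world_size)]

# Every rank should end up with the elementwise sum of all inputs.
expected = torch.stack(inputs).sum(dim=0)

# In the real test, `out` would come from the custom kernel on each rank; here it is a placeholder.
out = sum(inputs)
torch.testing.assert_close(out, expected)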