Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc4325b6
Unverified
Commit
cc4325b6
authored
Sep 24, 2024
by
Hanzhi Zhou
Committed by
GitHub
Sep 24, 2024
Browse files
[Bugfix] Fix potentially unsafe custom allreduce synchronization (#8558)
parent
8ff7ced9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
59 deletions
+83
-59
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+74
-54
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+9
-5
No files found.
csrc/custom_all_reduce.cuh
View file @
cc4325b6
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <iostream>
#include <iostream>
#include <array>
#include <limits>
#include <limits>
#include <map>
#include <map>
#include <unordered_map>
#include <unordered_map>
...
@@ -23,17 +24,23 @@
...
@@ -23,17 +24,23 @@
namespace
vllm
{
namespace
vllm
{
constexpr
int
kMaxBlocks
=
64
;
constexpr
int
kMaxBlocks
=
36
;
// note: we don't want to use atomics for signals because peer atomics are no
// Counter may overflow, but it's fine since unsigned int overflow is
// supported on PCIe links
// well-defined behavior.
using
FlagType
=
uint32_t
;
struct
Signal
{
struct
Signal
{
alignas
(
128
)
uint32_t
start
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
self_counter
[
kMaxBlocks
][
8
];
alignas
(
128
)
uint32_t
end
[
kMaxBlocks
][
8
];
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for peer GPU block to arrive at the second sync point while
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
alignas
(
128
)
FlagType
peer_counter
[
2
][
kMaxBlocks
][
8
];
};
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
Signal
*
signals
[
8
];
};
// like std::array, but aligned
// like std::array, but aligned
template
<
typename
T
,
int
sz
>
template
<
typename
T
,
int
sz
>
...
@@ -123,47 +130,60 @@ DINLINE O downcast(array_t<float, O::size> val) {
...
@@ -123,47 +130,60 @@ DINLINE O downcast(array_t<float, O::size> val) {
}
}
}
}
// This function is meant to be used as the first synchronization in the all
static
DINLINE
void
st_flag_release
(
FlagType
*
flag_addr
,
FlagType
flag
)
{
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
asm
volatile
(
"st.release.sys.global.u32 [%1], %0;"
::
"r"
(
flag
),
// prior memory accesses. Note: volatile writes will not be reordered against
"l"
(
flag_addr
));
// other volatile writes.
}
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
static
DINLINE
FlagType
ld_flag_acquire
(
FlagType
*
flag_addr
)
{
int
rank
)
{
FlagType
flag
;
if
(
threadIdx
.
x
<
ngpus
)
{
asm
volatile
(
"ld.acquire.sys.global.u32 %0, [%1];"
// reset flag for next time
:
"=r"
(
flag
)
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
]
=
0
;
:
"l"
(
flag_addr
));
// simultaneously write to the corresponding flag of all ranks.
return
flag
;
// Latency = 1 p2p write
}
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
static
DINLINE
void
st_flag_volatile
(
FlagType
*
flag_addr
,
FlagType
flag
)
{
while
(
!
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
]);
asm
volatile
(
"st.volatile.global.u32 [%1], %0;"
::
"r"
(
flag
),
"l"
(
flag_addr
));
}
}
__syncthreads
();
static
DINLINE
FlagType
ld_flag_volatile
(
FlagType
*
flag_addr
)
{
FlagType
flag
;
asm
volatile
(
"ld.volatile.global.u32 %0, [%1];"
:
"=r"
(
flag
)
:
"l"
(
flag_addr
));
return
flag
;
}
}
// This function is meant to be used as the second or the final synchronization
// is_start: whether this is the very first synchronization barrier.
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// need_fence: whether a memory fence is needed. If true, a release-acquire
// we don't need to make any visibility guarantees for prior memory accesses.
// semantic is used to enforce memory access order before and after this
template
<
int
ngpus
,
bool
final_sync
=
false
>
// barrier.
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
template
<
int
ngpus
,
bool
is_start
,
bool
need_fence
=
false
>
int
rank
)
{
DINLINE
void
multi_gpu_barrier
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
__syncthreads
();
int
rank
)
{
// eliminate the case that prior writes are not visible after signals become
if
constexpr
(
!
is_start
)
__syncthreads
();
// visible. Note that I did not managed to make this happen through a lot of
static_assert
(
// testing. Might be the case that hardware provides stronger guarantee than
!
(
is_start
&&
need_fence
));
// Start barrier shouldn't need fence.
// the memory model.
if
constexpr
(
!
final_sync
)
__threadfence_system
();
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
// reset flag for next time
// Increment the counter. Technically we only need one counter, but we use
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
]
=
0
;
// multiple per block to eliminate the need to share the counter via smem.
// simultaneously write to the corresponding flag of all ranks.
auto
val
=
self_sg
->
self_counter
[
blockIdx
.
x
][
threadIdx
.
x
]
+=
1
;
// Latency = 1 p2p write
// Write the expected counter value to peer and wait for correct value from
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
]
=
1
;
// peer.
// wait until we got true from all ranks
auto
peer_counter_ptr
=
while
(
!
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
]);
&
sg
.
signals
[
threadIdx
.
x
]
->
peer_counter
[
val
%
2
][
blockIdx
.
x
][
rank
];
auto
self_counter_ptr
=
&
self_sg
->
peer_counter
[
val
%
2
][
blockIdx
.
x
][
threadIdx
.
x
];
if
constexpr
(
need_fence
)
{
st_flag_release
(
peer_counter_ptr
,
val
);
while
(
ld_flag_acquire
(
self_counter_ptr
)
!=
val
);
}
else
{
st_flag_volatile
(
peer_counter_ptr
,
val
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
val
);
}
}
}
if
constexpr
(
!
final_sy
nc
)
__syncthreads
();
if
constexpr
(
is_start
||
need_fe
nc
e
)
__syncthreads
();
}
}
template
<
typename
P
,
int
ngpus
,
typename
A
>
template
<
typename
P
,
int
ngpus
,
typename
A
>
...
@@ -178,33 +198,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
...
@@ -178,33 +198,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
template
<
typename
T
,
int
ngpus
>
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
Signal
*
self_sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
rank
,
int
size
)
{
using
P
=
typename
packed_t
<
T
>::
P
;
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
using
A
=
typename
packed_t
<
T
>::
A
;
// note: we don't reorder the address so the accumulation order is the same
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
// for all ranks, ensuring bitwise identical results
auto
dp
=
*
_dp
;
auto
dp
=
*
_dp
;
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
multi_gpu_barrier
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
// do the actual reduction
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
}
end_sync
<
ngpus
,
tru
e
>
(
sg
,
self_sg
,
rank
);
multi_gpu_barrier
<
ngpus
,
fals
e
>
(
sg
,
self_sg
,
rank
);
}
}
template
<
typename
P
>
template
<
typename
P
>
DINLINE
P
*
get_tmp_buf
(
volatile
Signal
*
sg
)
{
DINLINE
P
*
get_tmp_buf
(
Signal
*
sg
)
{
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
}
}
template
<
typename
T
,
int
ngpus
>
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
Signal
*
self_sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
rank
,
int
size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
using
P
=
typename
packed_t
<
T
>::
P
;
using
P
=
typename
packed_t
<
T
>::
P
;
...
@@ -222,12 +240,12 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -222,12 +240,12 @@ __global__ void __launch_bounds__(512, 1)
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
}
auto
tmp_out
=
tmps
[
0
];
auto
tmp_out
=
tmps
[
0
];
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
multi_gpu_barrier
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
// stage 1: reduce scatter
// stage 1: reduce scatter
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
}
}
end_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
multi_gpu_barrier
<
ngpus
,
false
,
true
>
(
sg
,
self_sg
,
rank
);
// stage 2: allgather. Note: it's important to match the tid between
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// the two stages, because visibility across devices is only guaranteed
...
@@ -437,6 +455,8 @@ class CustomAllreduce {
...
@@ -437,6 +455,8 @@ class CustomAllreduce {
#define KL(ngpus, name) \
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
#define REDUCE_CASE(ngpus) \
case ngpus: { \
case ngpus: { \
if (world_size_ == 2) { \
if (world_size_ == 2) { \
...
...
csrc/custom_all_reduce_test.cu
View file @
cc4325b6
/**
/**
* This is a standalone test for custom allreduce.
* This is a standalone test for custom allreduce.
* To compile, make sure you have MPI and NCCL installed in your system.
* To compile, make sure you have MPI and NCCL installed in your system.
* export MPI_HOME=
XXX
* export MPI_HOME=
xxx
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* custom_all_reduce_test -lnccl -I${MPI_HOME}
/include
-lmpi
* custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
*
*
* Warning: this C++ test is not designed to be very readable and was used
* Warning: this C++ test is not designed to be very readable and was used
* during the rapid prototyping process.
* during the rapid prototyping process.
*
*
* To run:
* To run:
* mpirun -np 8 ./custom_all_reduce_test
* mpirun
--allow-run-as-root
-np 8 ./custom_all_reduce_test
*/
*/
#include <cuda.h>
#include <cuda.h>
#include <curand_kernel.h>
#include <curand_kernel.h>
...
@@ -302,15 +302,19 @@ int main(int argc, char** argv) {
...
@@ -302,15 +302,19 @@ int main(int argc, char** argv) {
bool
performance_test
=
true
;
bool
performance_test
=
true
;
cudaProfilerStart
();
cudaProfilerStart
();
// for (int threads : {256, 512}) {
// Uncomment to scan through different block size configs.
// for (int threads : {256, 512, 1024}) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
// performance_test);
// }
// }
// }
// }
// Scan through different sizes to test performance.
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
}
}
cudaProfilerStop
();
cudaProfilerStop
();
MPICHECK
(
MPI_Finalize
());
return
EXIT_SUCCESS
;
return
EXIT_SUCCESS
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment