Unverified Commit e1e2b76e authored by Pavel Shamis (Pasha)'s avatar Pavel Shamis (Pasha) Committed by GitHub
Browse files

Fixing potential integer overflow on sequence counter (#729)



* Fixing potential integer overflow on sequence counter

The current implementation may potentially cause hangs or data corruption.
Signed-off-by: Pasha (Pavel) Shamis <pasharesearch@gmail.com>

* Fixing typo in comments

Addressing reviewers' comments
Signed-off-by: Pasha (Pavel) Shamis <pasharesearch@gmail.com>

---------
Signed-off-by: Pasha (Pavel) Shamis <pasharesearch@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1fa5bf18
...@@ -51,6 +51,10 @@ ...@@ -51,6 +51,10 @@
asm volatile("fence.sc.gpu;\n"); \ asm volatile("fence.sc.gpu;\n"); \
} }
// CHECK_IDS(producer, consumer): overflow-safe replacement for the signed test
// `producer < consumer` on wrapping sequence counters.
//
// Evaluates to 1 while `producer` is still BEHIND `consumer` (the waiter must keep
// spinning) and to 0 once `producer` has reached or passed `consumer`.
//
// Both ids are converted to unsigned and subtracted, so the difference wraps modulo
// 2^N instead of overflowing; the sign bit of that difference tells which id is ahead.
// The answer is only meaningful while the two counters are within 2^31 of each other:
// a producer that is 2B+ messages behind the consumer is misread as being ahead.
//
// Each argument is evaluated exactly once, so passing a volatile read such as `*flag`
// performs a single load per evaluation.
#define CHECK_IDS(producer, consumer) \
  ((((unsigned)(producer) - (unsigned)(consumer)) & ~(unsigned)INT_MAX) != 0u)
template <int RANKS> template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS) __global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank, userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank,
...@@ -74,7 +78,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -74,7 +78,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -128,7 +132,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -128,7 +132,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu]; volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) { if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -162,7 +166,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -162,7 +166,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -211,7 +215,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -211,7 +215,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu]; volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) { if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -273,7 +277,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -273,7 +277,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -348,7 +352,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -348,7 +352,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -422,7 +426,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -422,7 +426,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -490,7 +494,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -490,7 +494,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu]; volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) { if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -525,7 +529,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -525,7 +529,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -610,7 +614,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -610,7 +614,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag); threadIdx.x, reduce_id, *flag);
...@@ -740,7 +744,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -740,7 +744,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu]; volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) { if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -800,7 +804,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ ...@@ -800,7 +804,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag); threadIdx.x, reduce_id, *flag);
...@@ -888,7 +892,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -888,7 +892,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag); threadIdx.x, reduce_id, *flag);
...@@ -975,7 +979,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -975,7 +979,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag); threadIdx.x, reduce_id, *flag);
...@@ -1072,7 +1076,7 @@ userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8( ...@@ -1072,7 +1076,7 @@ userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8(
volatile int* flag = (volatile int*)&(myptr[targetgpu]); volatile int* flag = (volatile int*)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu+handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu+handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64()-s > TIMEOUT) { if (clock64()-s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n",
myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); myrank, blockIdx.x, threadIdx.x, reduce_id, *flag);
...@@ -1171,7 +1175,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1171,7 +1175,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag); threadIdx.x, reduce_id, *flag);
...@@ -1270,7 +1274,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1270,7 +1274,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]); volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]); userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x,
threadIdx.x, reduce_id, *flag); threadIdx.x, reduce_id, *flag);
...@@ -1389,7 +1393,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1389,7 +1393,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu]; volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) { if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -1486,7 +1490,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1486,7 +1490,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
flagptr[physgpu] = reduce_id; flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu]; volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64(); clock_t s = clock64();
while (*flag < reduce_id) { while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) { if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag); *flag);
...@@ -1517,7 +1521,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1517,7 +1521,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
} }
volatile int *flag = (volatile int *)&((reinterpret_cast<int *>( volatile int *flag = (volatile int *)&((reinterpret_cast<int *>(
commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]); commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]);
while (*flag < basecounter) { while (CHECK_IDS(*flag, basecounter)) {
} }
} }
__syncthreads(); __syncthreads();
...@@ -1635,7 +1639,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1635,7 +1639,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
const int end_aligned = start_elem + aligned_elem; const int end_aligned = start_elem + aligned_elem;
if (mythreadIdx == 0) { if (mythreadIdx == 0) {
while (*flag < gathercounter) { while (CHECK_IDS(*flag, gathercounter)) {
} }
gathercounter++; gathercounter++;
} }
...@@ -1694,7 +1698,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ ...@@ -1694,7 +1698,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter;
} }
volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank];
while (*flag < basecounter) { while (CHECK_IDS(*flag, basecounter)) {
} }
} }
__syncthreads(); __syncthreads();
...@@ -1864,7 +1868,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ ...@@ -1864,7 +1868,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
const int end_aligned = start_elem + aligned_elem; const int end_aligned = start_elem + aligned_elem;
if (mythreadIdx == 0) { if (mythreadIdx == 0) {
while (*flag < gathercounter) { while (CHECK_IDS(*flag, gathercounter)) {
} }
gathercounter++; gathercounter++;
} }
...@@ -1908,7 +1912,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ ...@@ -1908,7 +1912,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter;
} }
volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank];
while (*flag < basecounter) { while (CHECK_IDS(*flag, basecounter)) {
} }
} }
__syncthreads(); __syncthreads();
...@@ -2114,7 +2118,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ ...@@ -2114,7 +2118,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
const int end_aligned = start_elem + aligned_elem; const int end_aligned = start_elem + aligned_elem;
if (mythreadIdx == 0) { if (mythreadIdx == 0) {
while (*flag < gathercounter) { while (CHECK_IDS(*flag, gathercounter)) {
} }
gathercounter++; gathercounter++;
} }
...@@ -3013,7 +3017,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -3013,7 +3017,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
const int signal_id = (*recv_id) + 1; const int signal_id = (*recv_id) + 1;
volatile int *flag = (volatile int *)recv_flagptr; volatile int *flag = (volatile int *)recv_flagptr;
clock_t s = clock64(); clock_t s = clock64();
while (*flag < signal_id) { while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); *flag);
...@@ -3073,7 +3077,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -3073,7 +3077,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
const int signal_id = (*recv_id) + 1; const int signal_id = (*recv_id) + 1;
volatile int *flag = (volatile int *)flagptr; volatile int *flag = (volatile int *)flagptr;
clock_t s = clock64(); clock_t s = clock64();
while (*flag < signal_id) { while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); *flag);
...@@ -3142,7 +3146,7 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f ...@@ -3142,7 +3146,7 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f
if (*flag >= signal_id) if (*flag >= signal_id)
return; return;
clock_t s = clock64(); clock_t s = clock64();
while (atomicAdd_system(flagptr, 0) < signal_id) { while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag);
return; return;
...@@ -3193,7 +3197,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -3193,7 +3197,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
if (*flag >= signal_id) if (*flag >= signal_id)
return; return;
clock_t s = clock64(); clock_t s = clock64();
while (*flag < signal_id) { while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); *flag);
...@@ -3245,7 +3249,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -3245,7 +3249,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)recv_flagptr; volatile int *flag = (volatile int *)recv_flagptr;
// if(*flag>=signal_id) return; // if(*flag>=signal_id) return;
clock_t s = clock64(); clock_t s = clock64();
while (*flag < signal_id) { while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); /*return;*/ *flag); /*return;*/
...@@ -3312,7 +3316,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -3312,7 +3316,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)recv_flagptr; volatile int *flag = (volatile int *)recv_flagptr;
// if(*flag>=signal_id) return; // if(*flag>=signal_id) return;
clock_t s = clock64(); clock_t s = clock64();
while (*flag < signal_id) { while (CHECK_IDS(*flag, signal_id)) {
if (clock64() - s > TIMEOUT) { if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag); /*return;*/ *flag); /*return;*/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment