Commit 07b750a2 authored by yuguo's avatar yuguo
Browse files

[DCU] tmp fix overlap allmethod

parent 8fb50d09
......@@ -212,7 +212,7 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--benchmark-iter",
type=int,
default=100,
default=20,
help="Number of iterations for benchmarking perf.",
)
parser.add_argument(
......@@ -376,6 +376,8 @@ def _train(opts):
ub_cfgs = {
"qkv_dgrad": {"method": "ring_exchange"},
"fc1_dgrad": {"method": "ring_exchange"},
"proj_fprop": {"method": "ring_exchange"},
"fc2_fprop": {"method": "ring_exchange"},
}
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
......@@ -498,11 +500,11 @@ def _train(opts):
if opts.benchmark:
# Warmup to not profile CPU overhead
for _ in range(100):
for _ in range(20):
if opts.use_cuda_graphs:
test_graph.replay()
else:
test_out = run_fwd_bwd(test_model, test_x)
test_out = run_fwd_bwd(ref_model, ref_x)
torch.cuda.cudart().cudaProfilerStart()
for _ in range(opts.benchmark_iter):
if opts.use_cuda_graphs:
......
......@@ -880,6 +880,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
}
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
}
}
......
......@@ -292,11 +292,17 @@ __global__ void __launch_bounds__(MAX_THREADS)
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
__threadfence_system();
reduce_id = (*reduceidptr) + 1;
__threadfence_system();
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
__threadfence_system();
if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
__threadfence_system();
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
__threadfence_system();
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
__threadfence_system();
clock_t s = clock64();
while (CHECK_IDS(*flag, reduce_id)) {
if (CHECK_TIMEOUT(s, ub_timeout)) {
......@@ -309,7 +315,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
__threadfence_system();
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
__threadfence_system();
if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
}
......@@ -340,8 +348,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
userptr[myrank][mylineoffset + line] = sum;
}
__threadfence_system();
if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
__threadfence_system();
} // fp16 inplace reduce-scatter kernel
template <int RANKS>
......@@ -359,7 +368,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
__threadfence_system();
reduce_id = (*reduceidptr) + 1;
__threadfence_system();
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
__threadfence_system();
if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
......@@ -380,7 +391,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
__syncthreads();
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
__threadfence_system();
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
__threadfence_system();
if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
}
......@@ -1237,9 +1250,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
__threadfence_system();
reduce_id = (*reduceidptr) + 1;
__threadfence_system();
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
__threadfence_system();
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
__threadfence_system();
clock_t s = clock64();
}
......@@ -1275,7 +1292,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
__shared__ int lastSM;
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
__threadfence_system();
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
__threadfence_system();
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
else
......@@ -1283,9 +1302,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
__syncthreads();
if (lastSM && threadIdx.x < RANKS) {
__threadfence_system();
if (threadIdx.x == 0) *reduceidptr = reduce_id;
__threadfence_system();
flagptr[physgpu] = reduce_id;
__threadfence_system();
volatile int *flag = (volatile int *)&myptr[targetgpu];
__threadfence_system();
clock_t s = clock64();
while (CHECK_IDS(*flag, reduce_id)) {
if (CHECK_TIMEOUT(s, ub_timeout)) {
......@@ -1314,9 +1337,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
targetgpu = threadIdx.x * gpustep + firstrank;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
__threadfence_system();
reduce_id = (*reduceidptr) + 1;
__threadfence_system();
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
__threadfence_system();
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
__threadfence_system();
}
__syncthreads();
localptr = userptr[myrank];
......@@ -1370,7 +1397,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
__shared__ int lastSM;
if (threadIdx.x == 0) {
const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
__threadfence_system();
int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
__threadfence_system();
if (old_val + adder == NVTE_MAX_SMS * reduce_id)
lastSM = 1;
else
......@@ -1378,9 +1407,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
__syncthreads();
if (lastSM && threadIdx.x < RANKS) {
__threadfence_system();
if (threadIdx.x == 0) *reduceidptr = reduce_id;
__threadfence_system();
flagptr[physgpu] = reduce_id;
__threadfence_system();
volatile int *flag = (volatile int *)&myptr[targetgpu];
__threadfence_system();
clock_t s = clock64();
while (CHECK_IDS(*flag, reduce_id)) {
if (CHECK_TIMEOUT(s, ub_timeout)) {
......@@ -2090,11 +2123,7 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
#endif
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
#ifdef __HIP_PLATFORM_AMD__
*flagptr = *flagptr + 1;
#else
atomicAdd_system(flagptr, 1);
#endif
}
__global__ void kuserbuffers_inc(int *id) { atomicAdd(id, 1); }
......@@ -2165,18 +2194,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
__syncthreads();
if (threadIdx.x) return;
__threadfence_system();
#ifdef __HIP_PLATFORM_AMD__
*flagptr = *flagptr + 1;
#else
atomicAdd_system(flagptr,
1); // otherwise need local SM sync before sending flag
#endif
} else { // 0 bytes and 1 SM only
#ifdef __HIP_PLATFORM_AMD__
*flagptr = *flagptr + 1;
#else
atomicAdd_system(flagptr, 1);
#endif
__threadfence_system();
}
}
......@@ -2188,7 +2210,9 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpe
int *ce_start_ptr, int *ce_end_ptr) {
const int signal_id = (*recv_id) + adder;
*recv_id = signal_id;
__threadfence_system();
volatile int *flag = (volatile int *)flagptr;
__threadfence_system();
if (*flag >= signal_id) return;
clock_t s = clock64();
while (CHECK_IDS(*flag, signal_id)) {
......@@ -2235,18 +2259,10 @@ __global__ void __launch_bounds__(MAX_THREADS)
__syncthreads();
if (threadIdx.x) return;
__threadfence_system();
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
#endif
} else { // 0 bytes and 1 SM only
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr, 1);
#endif
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
......@@ -2301,18 +2317,10 @@ __global__ void __launch_bounds__(MAX_THREADS)
__syncthreads();
if (threadIdx.x) return;
__threadfence_system();
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
#endif
} else { // 0 bytes and 1 SM only
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr, 1);
#endif
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
......@@ -2382,19 +2390,11 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
__syncthreads();
if (!threadIdx.x) {
__threadfence_system();
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
#endif
}
} else { // 0 bytes and 1 SM only
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr, 1);
#endif
}
// wait for message to arrive.
......@@ -2466,9 +2466,6 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
int peerlocal = peer % comm->nvsize;
void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0);
// void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1);
......@@ -2500,17 +2497,11 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
NVTE_CHECK_CUDA(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
}
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset,
const size_t recv_offset, const size_t bytes, communicator *comm,
const int send_peer, const int recv_peer, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
int send_peerlocal = send_peer % comm->nvsize;
int recv_peerlocal = recv_peer % comm->nvsize;
......@@ -2560,18 +2551,12 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
reinterpret_cast<void *>(&arg15)};
NVTE_CHECK_CUDA(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv), kernelArgs));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator *comm, const int send_peer,
const int recv_peer, void *counters, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
assert(comm->push && comm->use_ce == 0);
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
......@@ -2623,9 +2608,6 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16)};
NVTE_CHECK_CUDA(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic), kernelArgs));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
......@@ -2633,9 +2615,6 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
const size_t bytes, communicator *comm, const int send_peer,
const int recv_peer, const int nchunks, void *counters,
bool shuffle, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
assert(comm->push && comm->use_ce == 0);
// CE is not supported
......@@ -2675,17 +2654,11 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)};
NVTE_CHECK_CUDA(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_multiatomic), kernelArgs));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
int peerlocal = peer % comm->nvsize;
void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0);
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
......@@ -2719,9 +2692,6 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
: nullptr));
}
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
// producer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment