Commit 456a96c8 authored by yuguo

[DCU] overlap bug fix in ECO and BW finally

parent b9ec4909
@@ -80,7 +80,7 @@
   printf("[%s:%s:%d] " message "\n", FILENAME(__FILE__), __FUNCTION__, __LINE__, __VA_ARGS__)
 // Report an error on timeout
-#define CHECK_TIMEOUT(t, timeout) ((clock64() - (t)) > timeout)
+#define CHECK_TIMEOUT(t, timeout) ((static_cast<int64_t>(clock64()) - static_cast<int64_t>(t)) > static_cast<int64_t>(timeout))
 template <int RANKS>
 __global__ void __launch_bounds__(MAX_THREADS)
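Why the extra casts help (a reading of the change, not text from the commit): if any operand of the old comparison is unsigned while the tick difference comes out negative, for example after the start tick was truncated through a narrower clock_t, the difference converts to a huge unsigned value and the timeout misfires. A minimal host-side sketch of that failure mode, with made-up values:

#include <cstdint>
#include <cstdio>

int main() {
  // Suppose the raw tick difference comes out negative (truncated start
  // tick, or a counter read on another SM); values are illustrative only.
  int64_t elapsed = -16;
  uint64_t timeout = 1000000;  // assume the timeout lives in an unsigned type
  // Mixed signed/unsigned comparison: elapsed converts to a huge unsigned
  // number, so this reports a timeout immediately.
  printf("mixed:  %d\n", (int)((uint64_t)elapsed > timeout));
  // All-signed comparison, which the patched CHECK_TIMEOUT enforces with
  // static_cast<int64_t> on every operand: no spurious timeout.
  printf("signed: %d\n", (int)(elapsed > (int64_t)timeout));
  return 0;
}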
@@ -292,17 +292,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
-    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
-    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
-    __threadfence_system();
     if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
-    __threadfence_system();
     volatile int *flag = (volatile int *)&(myptr[targetgpu]);
-    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
-    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
@@ -315,9 +309,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __syncthreads();
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
-    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
-    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
@@ -348,9 +340,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
     userptr[myrank][mylineoffset + line] = sum;
   }
-  __threadfence_system();
   if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
-  __threadfence_system();
 }  // fp16 inplace reduce-scatter kernel
 template <int RANKS>
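The hunks above delete the __threadfence_system() calls that bracketed nearly every flag access in the DCU build; what remains is the counter-based handshake itself. A stripped-down sketch of that handshake (simplified names, one flag slot per peer; not the library's actual signatures):

#include <cstdio>

#define SPIN_LIMIT (1ll << 32)  // stands in for ub_timeout (assumption)

__global__ void handshake(int *my_flags, int *peer_flags, int my_rank,
                          int peer_rank, int step) {
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    peer_flags[my_rank] = step;                 // post our arrival to the peer
    volatile int *flag = my_flags + peer_rank;  // slot the peer writes into
    long long s = clock64();
    while (*flag < step) {                      // CHECK_IDS equivalent
      if (clock64() - s > SPIN_LIMIT) {         // CHECK_TIMEOUT equivalent
        printf("handshake timeout at step %d\n", step);
        break;
      }
    }
  }
  __syncthreads();  // the whole block proceeds once the flag round-trips
}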
@@ -368,17 +358,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
-    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
-    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
-    __threadfence_system();
     if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
-    __threadfence_system();
     volatile int *flag = (volatile int *)&(myptr[targetgpu]);
-    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
-    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
@@ -391,9 +375,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
   __syncthreads();
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
-    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
-    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
@@ -424,9 +406,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     (reinterpret_cast<int4 *>(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum;
   }
-  __threadfence_system();
   if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
-  __threadfence_system();
 }  // fp16 reduce-scatter kernel (out of place)
 #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
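For the out-of-place variant, the destination index (line / rowlines) * skiplines + (line % rowlines) scatters contiguous reduced lines into a pitched output: each logical row holds rowlines int4 lines, but rows start skiplines apart. A tiny host-side illustration with made-up sizes:

#include <cstdio>

int main() {
  const int rowlines = 4;   // int4 lines per logical row
  const int skiplines = 6;  // row pitch of the destination buffer
  for (int line = 0; line < 8; ++line)
    printf("line %d -> slot %d\n", line,
           (line / rowlines) * skiplines + (line % rowlines));
  // Lines 0..3 land in slots 0..3 (row 0); lines 4..7 land in slots 6..9
  // (row 1), leaving slots 4..5 as row padding.
  return 0;
}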
@@ -1250,13 +1230,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
-    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
-    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
-    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
-    __threadfence_system();
     clock_t s = clock64();
   }
@@ -1292,9 +1268,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __shared__ int lastSM;
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
-    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
-    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id)
       lastSM = 1;
     else
@@ -1302,13 +1276,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
   }
   __syncthreads();
   if (lastSM && threadIdx.x < RANKS) {
-    __threadfence_system();
     if (threadIdx.x == 0) *reduceidptr = reduce_id;
-    __threadfence_system();
     flagptr[physgpu] = reduce_id;
-    __threadfence_system();
     volatile int *flag = (volatile int *)&myptr[targetgpu];
-    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
@@ -1337,13 +1307,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
-    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
-    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
-    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
-    __threadfence_system();
   }
   __syncthreads();
   localptr = userptr[myrank];
@@ -1397,9 +1363,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __shared__ int lastSM;
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
-    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
-    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id)
       lastSM = 1;
     else
@@ -1407,13 +1371,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
   }
   __syncthreads();
   if (lastSM && threadIdx.x < RANKS) {
-    __threadfence_system();
     if (threadIdx.x == 0) *reduceidptr = reduce_id;
-    __threadfence_system();
     flagptr[physgpu] = reduce_id;
-    __threadfence_system();
     volatile int *flag = (volatile int *)&myptr[targetgpu];
-    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
@@ -2197,8 +2157,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
       atomicAdd_system(flagptr,
                        1);  // otherwise need local SM sync before sending flag
     } else {  // 0 bytes and 1 SM only
+#if defined(__gfx928__) || defined(__gfx926__) || defined(__gfx906__)
       *flagptr = *flagptr + 1;
-      __threadfence_system();
+#else
+      atomicAdd_system(flagptr, 1);
+#endif
     }
   }
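The hunk above replaces the plain increment-plus-fence with a compile-time split: the direct read-modify-write survives only on the listed gfx9 DCU targets, while all other builds signal with a system-scope atomic. A self-contained sketch of the same split (the coherence rationale is an inference, not stated in the commit):

__device__ void signal_peer(int *flagptr) {
#if defined(__gfx928__) || defined(__gfx926__) || defined(__gfx906__)
  // Single writer on an already-coherent mapping: a plain RMW suffices here.
  *flagptr = *flagptr + 1;
#else
  // System-scope atomic: the increment is visible to other GPUs and the host.
  atomicAdd_system(flagptr, 1);
#endif
}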
@@ -2210,9 +2173,7 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpe
                                       int *ce_start_ptr, int *ce_end_ptr) {
   const int signal_id = (*recv_id) + adder;
   *recv_id = signal_id;
-  __threadfence_system();
   volatile int *flag = (volatile int *)flagptr;
-  __threadfence_system();
   if (*flag >= signal_id) return;
   clock_t s = clock64();
   while (CHECK_IDS(*flag, signal_id)) {
......
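kuserbuffers_pushrecv waits the same way, but on a monotonically increasing signal id rather than a boolean flag, so nothing ever has to be reset between operations and sender and receiver cannot race on a clear. A host-side sketch of that counter protocol (values hypothetical):

#include <cstdio>

int main() {
  int flag = 0;     // peer-visible counter, only ever incremented by senders
  int recv_id = 0;  // receiver's record of the last consumed signal
  for (int op = 0; op < 3; ++op) {
    flag += 1;                    // sender: post one more signal
    int signal_id = recv_id + 1;  // receiver: id this operation waits for
    recv_id = signal_id;
    while (flag < signal_id) { }  // spin; the kernel adds a clock64() timeout
    printf("op %d consumed signal %d\n", op, signal_id);
  }
  return 0;
}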
@@ -56,11 +56,7 @@ def get_cublas_workspace_size_bytes() -> None:
     """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
     # Add env for control the padding for blaslt
     if IS_HIP_EXTENSION:
-        nvte_blaslt_nopad = int(os.environ.get("NVTE_BLASLT_NOPAD", 0))
-        if(nvte_blaslt_nopad):
-            return 536_870_912
-        else:
-            return 1_073_741_824
+        return 134_217_728
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
         return 33_554_432
     return 4_194_304
......
@@ -253,7 +253,7 @@ if IS_HIP_EXTENSION:
         import re
         return (re.search('K100_AI', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
-    def is_BW3000():
+    def is_BW():
         """check whether this machine is BW"""
         import re
         return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
@@ -264,7 +264,7 @@ def is_bf16_compatible() -> None:
     """
     if IS_HIP_EXTENSION:
         # only MI200 and MI300 machines support bf16
-        if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW3000():
+        if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW():
             return True
         else:
             return False
......