Commit caf2fbf2 authored by yuguo

[DCU] tp overlap opt

parent 0b0a70a5
@@ -4,6 +4,7 @@
 """Installation script."""
 # NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
+# VTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
 import os
 import sys
...
@@ -312,6 +312,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapBase::bulk_overlap
@@ -461,7 +462,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                      accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()]);
+                     _stream_compute[i % _stream_compute.size()], 1, 0);
     NVTE_CHECK_CUDA(
         cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
@@ -509,7 +510,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                      accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()]);
+                     _stream_compute[i % _stream_compute.size()], 1, 0);
     NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
@@ -540,6 +541,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapBase::split_overlap_rs
 /***************************************************************************************************
@@ -775,8 +777,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   }
   if (_aggregate) {
     const int num_steps = _tp_size / 2;
+#ifndef __HIP_PLATFORM_AMD__
     input_chunk_size *= 2;
     output_chunk_size *= 2;
+#endif
     // Initial 1X input chunk exchange between neighboring peers
     int send_chunk_id = _tp_id;
@@ -817,7 +821,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                        aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                        use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()]);
+                       _stream_compute[i % _stream_compute.size()], 1, 0);
       if (i < num_steps - 1) {
         // P2P communication
@@ -861,7 +865,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                        aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                        use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()]);
+                       _stream_compute[i % _stream_compute.size()], 1, 0);
       if (i < _tp_size - 1) {
         // P2P communication
@@ -892,6 +896,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapP2PBase::split_overlap_ag
 /*
@@ -1005,7 +1010,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
     nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                     use_split_accumulator, _math_sms, _stream_compute[stream_id]);
+                     use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0);
     if (i > 0) {
       // P2P communication chunk
@@ -1034,6 +1039,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
   // Reduce GEMM output chunks
   char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
...
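
Two host-side changes recur in the hunks above: each overlap path now ends with a full `cudaDeviceSynchronize()` after the usual event-based hand-off back to `stream_main`, and the chunked `nvte_cublas_gemm` calls gain two trailing integer arguments (`1, 0`). Below is a minimal sketch of the event-then-barrier pattern only; `finish_overlap`, the stream/event names, and the `check` macro are illustrative, not part of TransformerEngine.

```cpp
#include <cstdlib>
#include <cuda_runtime.h>

// Hedged sketch: order the main stream after work queued on a communication
// stream, then add the device-wide barrier this patch introduces.
#define check(call)                                \
  do {                                             \
    cudaError_t err_ = (call);                     \
    if (err_ != cudaSuccess) std::abort();         \
  } while (0)

void finish_overlap(cudaStream_t comm_stream, cudaStream_t main_stream, cudaEvent_t done_event) {
  // Event-based ordering: main_stream waits until comm_stream reaches done_event.
  check(cudaEventRecord(done_event, comm_stream));
  check(cudaStreamWaitEvent(main_stream, done_event, 0));
  // Conservative extra barrier added by this patch: block the host until
  // every stream on the device has drained.
  check(cudaDeviceSynchronize());
}
```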
@@ -116,8 +116,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     reduce_id++;
   }
   __syncthreads();
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -201,8 +204,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     reduce_id++;
   }
   __syncthreads();
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -312,8 +318,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -378,8 +387,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -823,7 +835,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -907,8 +923,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), /*numchunks * */ adder);
     if (old_val + adder == NVTE_MAX_SMS * (reduce_id /* + numchunks*/)) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -988,7 +1007,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -1084,7 +1107,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -1181,7 +1208,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
@@ -1236,7 +1267,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     clock_t s = clock64();
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
   int skipmy = 0;
@@ -1314,7 +1349,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __syncthreads();
   localptr = userptr[myrank];
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS - 1];
   int skipmy = 0;
 #pragma unroll
@@ -1719,6 +1758,12 @@ __global__ void __launch_bounds__(MAX_THREADS)
         kernelArgs)); \
   }

+#ifdef __HIP_PLATFORM_AMD__
+#define WARPSIZE 64
+#else
+#define WARPSIZE 32
+#endif
+
 void reducescatter2_userbuff_strided(void *output, const int handler, const int offset,
                                      const int rowelements, const int colelements,
                                      const int strideelements, communicator *comm,
@@ -1733,10 +1778,10 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+  SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
   callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8)
       callranks_rs_oop_stride(16) callranks_rs_oop_stride(32)
 }
@@ -1755,10 +1800,10 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+  SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
   callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4)
       callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(16)
           callranks_rs_oop_stride_atomic(32)
@@ -1782,10 +1827,10 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c
   assert(comm->sm_arch >= 9);
   if (elements < 128) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+  SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
   callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8)
       callranks_rs_oop_atomic_fp8(16) callranks_rs_oop_atomic_fp8(32)
 }
@@ -1827,10 +1872,10 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+  SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
   callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4)
       callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(16)
           callranks_rs_oop_stride_multiatomic(32)
@@ -1848,18 +1893,18 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;
   if (comm_launch_event) {
-    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * WARPSIZE, stream, comm_launch_event);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
     } else {
       callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32)
     }
   } else {
-    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
     } else {
@@ -1895,18 +1940,18 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;
   if (comm_launch_event) {
-    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * WARPSIZE, stream, comm_launch_event);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
     } else {
       callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32)
     }
   } else {
-    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
     } else {
@@ -1928,11 +1973,11 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;
   if (comm_launch_event) {
-    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * WARPSIZE, stream, comm_launch_event);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16)
           callranks_rs_oopMC(32)
@@ -1941,7 +1986,7 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
           callranks_rs_oop(32)
     }
   } else {
-    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16)
           callranks_rs_oopMC(32)
@@ -1974,15 +2019,15 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
   assert(comm->sm_arch >= 9);
   if (elements < 128) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;
   if (comm_launch_event) {
-    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * WARPSIZE, stream, comm_launch_event);
     callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16)
         callranks_rs_oop_fp8(32)
   } else {
-    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
     callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16)
         callranks_rs_oop_fp8(32)
   }
...
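
The kernel and launch-configuration edits above all encode one fact: a wavefront on these AMD GPUs is 64 threads wide, so the per-warp index shifts `threadIdx.x` by 6 instead of 5, and host-side launch math divides by `WARPSIZE` instead of a hard-coded 32. Below is a compilable sketch of that pattern, assuming nothing beyond the CUDA/HIP builtins; `my_reduce_kernel` and the host snippet in the comments are illustrative names, not part of the patch.

```cpp
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#define WARPSIZE 64  // one AMD wavefront = 64 threads
#else
#define WARPSIZE 32  // one NVIDIA warp = 32 threads
#endif

// Hypothetical kernel: compute a per-wavefront index the same way the
// userbuffers kernels do after this patch.
__global__ void my_reduce_kernel(int *warp_of_thread) {
#ifdef __HIP_PLATFORM_AMD__
  int warp = blockIdx.x + (threadIdx.x >> 6);  // threadIdx.x / 64
#else
  int warp = blockIdx.x + (threadIdx.x >> 5);  // threadIdx.x / 32
#endif
  warp_of_thread[blockIdx.x * blockDim.x + threadIdx.x] = warp;
}

// Host-side launch math is adjusted the same way:
//   int warps = comm_threads / WARPSIZE;   // was comm_threads / 32
//   dim3 block(warps * WARPSIZE);          // was warps * 32
//   my_reduce_kernel<<<num_sms, block, 0, stream>>>(out);
```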
@@ -93,11 +93,11 @@ def general_gemm(
     transb = layout[1] == "T"
     # assert quantization_params is None, "FP8 output not supported yet"
-    if ub_type is not None:
-        assert ub is not None, (
-            f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires"
-            + "a valid `ub` communicator object."
-        )
+    # if ub_type is not None:
+    #     assert ub is not None, (
+    #         f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires"
+    #         + "a valid `ub` communicator object."
+    #     )
     if ub is not None:
         assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
...
@@ -47,7 +47,7 @@ _multi_stream_cublas_workspace = []
 _multi_stream_cublas_batchgemm_workspace = []
 _cublas_workspace = None
 _ub_communicators = None
-_NUM_MAX_UB_STREAMS = 1 if IS_HIP_EXTENSION else 3
+_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []
@@ -92,6 +92,10 @@ def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
         )
     return _multi_stream_cublas_batchgemm_workspace

+if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))):
+    remove_ag_gemm_dgrad = ["fc2_dgrad"]
+else:
+    remove_ag_gemm_dgrad = []

 def initialize_ub(
     shape: list,
@@ -237,11 +241,18 @@ def initialize_ub(
     layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
     dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
     # Default overlap methods for layers
-    methods = {
-        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
-        "pipeline": ["proj_fprop", "fc2_fprop"],
-        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
-    }
+    if bool(int(os.getenv("NVTE_NO_PIPELINE_OVERLAP", "0"))):
+        methods = {
+            "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop", "fc2_fprop"],
+            "pipeline": [],
+            "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
+        }
+    else:
+        methods = {
+            "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
+            "pipeline": ["proj_fprop", "fc2_fprop"],
+            "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
+        }

     # AG-RS overlap pairs of layers forming a tensor-parallel block
     ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
@@ -264,7 +275,7 @@ def initialize_ub(
         default_cfg = {
             "method": method,
             "is_reduce_scatter": is_reduce_scatter,
-            "num_sm": 1 if method == "ring_exchange" else 16,
+            "num_sm": 1 if method == "ring_exchange" else 8,
             "cga_size": 1 if method == "ring_exchange" else 2,
             "set_sm_margin": not method == "ring_exchange",
             "num_splits": tp_size if method == "ring_exchange" else 4,
@@ -377,6 +388,8 @@ def initialize_ub(
             methods[new_method].append(name)

     for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
+        if name in remove_ag_gemm_dgrad:
+            continue
         ub_cfg = get_default_config(name)
         if ub_cfgs is not None and name in ub_cfgs:
             fp8_buf = (name in layers_all_gather_overlap) or (
@@ -390,7 +403,9 @@
 def get_ub(name: str):
     """Get userbuffer communicator corresponding to give key."""
     assert _ub_communicators is not None, "UB manager is not initialized."
-    assert name in _ub_communicators, f"UB for {name} is not registered."
+    # assert name in _ub_communicators, f"UB for {name} is not registered."
+    if name in remove_ag_gemm_dgrad:
+        return None
     return _ub_communicators[name]
@@ -841,7 +856,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # Non-FP8 case: bgrad is fused with wgrad for this case.
             if not ctx.fp8:
                 if gather_grad_output:
-                    if not ctx.ub_overlap_ag:
+                    if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
                         grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                     else:
                         ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
@@ -853,7 +868,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         grad_bias = None
         if ctx.use_bias:
             grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
-        if ctx.ub_overlap_ag:
+        if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
             # Quantize the gradient if needed
             if not isinstance(
                 grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)
...
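
On the Python side, the patch raises the ROCm userbuffer stream count from 1 to 2 and adds two environment switches: `NVTE_NO_PIPELINE_OVERLAP=1` routes `proj_fprop`/`fc2_fprop` through ring-exchange instead of the pipeline method, and `NVTE_DISABLE_FC2_DGRAD_OVERLAP=1` skips the `fc2_dgrad` userbuffer, with `get_ub` returning `None` and the backward pass falling back to a plain all-gather. A hedged usage sketch follows; it assumes the public `transformer_engine.pytorch.initialize_ub` entry point and an already-initialized tensor-parallel process group, and the shape/tp_size values are placeholders.

```python
import os

# Set the switches before importing Transformer Engine so the patched
# module-level and initialize_ub() checks above see them.
os.environ["NVTE_NO_PIPELINE_OVERLAP"] = "1"        # ring-exchange for proj_fprop / fc2_fprop
os.environ["NVTE_DISABLE_FC2_DGRAD_OVERLAP"] = "1"  # no AG+GEMM overlap buffer for fc2_dgrad

import torch
import transformer_engine.pytorch as te

# Assumes torch.distributed has already been initialized across the tensor-parallel group.
te.initialize_ub(
    shape=[4096, 8192],   # placeholder: [sequence_length * micro_batch, hidden_size]
    tp_size=8,            # placeholder: tensor-parallel world size
    use_fp8=False,
    dtype=torch.bfloat16,
)
```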