Commit 75e9ef24 authored by yuguo's avatar yuguo
Browse files
parents 5753c5bb 291fcf52
......@@ -45,6 +45,21 @@ bool ubuf_built_with_mpi() {
#endif
}
/// Read an integer configuration value from environment variable `name`.
///
/// \param name   Environment variable to read.
/// \param defval Value returned when the variable is unset or empty.
/// \param minval Lower bound applied to any parsed value (parsed values
///               below it are clamped up to it; defval is NOT clamped,
///               matching the original behavior).
/// \return defval if unset/empty, otherwise the parsed value clamped to
///         [minval, INT_MAX].
static inline int getIntEnv(const char *name, int defval, int minval)
{
  const char *env = std::getenv(name);
  if (env == nullptr || env[0] == '\0')
  {
    return defval;
  }
  // std::strtol instead of atoi: atoi has undefined behavior on overflow,
  // while strtol saturates and reports where parsing stopped.
  char *end = nullptr;
  long parsed = std::strtol(env, &end, 10);
  if (end == env)
  {
    // No digits at all: atoi would have returned 0; keep that behavior
    // so garbage values still fall through to the minval clamp below.
    parsed = 0;
  }
  if (parsed > INT_MAX)
  {
    parsed = INT_MAX;
  }
  if (parsed < minval)
  {
    parsed = minval;
  }
  return static_cast<int>(parsed);
}
CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
......@@ -74,10 +89,41 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
_gemm_priority = gemm_priority;
_comm_priority = comm_priority;
}
int comm_cu_nums = getIntEnv("NVTE_UB_COMM_CU_NUMS", 8, 4);
unsigned int cuMask[4];
unsigned int cuMaskSize = 4;
if (comm_cu_nums == 4) {
cuMask[0] = 0xfffffff0;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 8) {
cuMask[0] = 0xffffff00;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 16) {
cuMask[0] = 0xffff0000;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 32) {
cuMask[0] = 0x00000000;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else {
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
}
static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
if (compute_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams[i], cuMaskSize, cuMask));
#endif
}
_stream_compute.push_back(compute_streams[i]);
}
......@@ -268,6 +314,10 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) {
const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
if(NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1'){
_ub_force_blas_multistream = true;
}
_ub_stream_nums = num_max_streams;
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
......@@ -282,10 +332,41 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
int comm_cu_nums = getIntEnv("NVTE_UB_COMM_CU_NUMS", 8, 4);
unsigned int cuMask[4];
unsigned int cuMaskSize = 4;
if (comm_cu_nums == 4) {
cuMask[0] = 0x0000000f;
cuMask[1] = 0x00000000;
cuMask[2] = 0x00000000;
cuMask[3] = 0x00000000;
} else if (comm_cu_nums == 8) {
cuMask[0] = 0x000000ff;
cuMask[1] = 0x00000000;
cuMask[2] = 0x00000000;
cuMask[3] = 0x00000000;
} else if (comm_cu_nums == 16) {
cuMask[0] = 0x0000ffff;
cuMask[1] = 0x00000000;
cuMask[2] = 0x00000000;
cuMask[3] = 0x00000000;
} else if (comm_cu_nums == 32) {
cuMask[0] = 0xffffffff;
cuMask[1] = 0x00000000;
cuMask[2] = 0x00000000;
cuMask[3] = 0x00000000;
} else {
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
}
static cudaStream_t comm_stream;
if (comm_stream == nullptr) {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(
hipExtStreamCreateWithCUMask(&comm_stream, cuMaskSize, cuMask));
#endif
}
_stream_comm = comm_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
......@@ -499,7 +580,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto bias_chunk = maybe_get_bias_chunk(0);
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
......@@ -516,7 +597,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
......@@ -572,7 +653,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
......@@ -631,7 +712,10 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
if(NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1'){
_ub_force_blas_multistream = true;
}
_ub_stream_nums = num_max_streams;
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
......@@ -902,7 +986,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
......@@ -962,7 +1046,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
......@@ -1115,7 +1199,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id]);
......
......@@ -784,17 +784,17 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
}
const char *NVTE_HIPBLAS_MULSTREAM = std::getenv("NVTE_FORCE_HIPBLAS_MULSTREAM");
const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
bool NVTE_FORCE_HIPBLAS_MULSTREAM;
if(NVTE_HIPBLAS_MULSTREAM != nullptr && NVTE_HIPBLAS_MULSTREAM[0] == '1'){
NVTE_FORCE_HIPBLAS_MULSTREAM = true;
if((NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') && (NVTE_HIPBLAS_MULSTREAM != nullptr && NVTE_HIPBLAS_MULSTREAM[0] == '1'))
NVTE_ERROR("NVTE_FORCE_HIPBLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
bool NVTE_FORCE_BLAS_MULSTREAM;
if(NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1'){
NVTE_FORCE_BLAS_MULSTREAM = true;
if((NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') && (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1'))
NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
} else{
NVTE_FORCE_HIPBLAS_MULSTREAM = false;
NVTE_FORCE_BLAS_MULSTREAM = false;
}
if (NVTE_FORCE_HIPBLAS_MULSTREAM){
if (NVTE_FORCE_BLAS_MULSTREAM){
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
......@@ -838,7 +838,7 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
using namespace transformer_engine;
int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);;
int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);
// Inits streams and events (once, globally)
std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
......
......@@ -138,6 +138,7 @@ class CommOverlapCore {
class CommOverlapBase : public CommOverlapCore {
protected:
int _ub_stream_nums;
bool _ub_force_blas_multistream;
int _rs_kernel_type;
bool _rs_overlap_first_gemm;
cudaStream_t _stream_comm;
......@@ -204,6 +205,7 @@ class CommOverlapBase : public CommOverlapCore {
class CommOverlapP2PBase : public CommOverlapCore {
protected:
int _ub_stream_nums;
bool _ub_force_blas_multistream;
bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false};
bool _aggregate;
......
......@@ -128,7 +128,7 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor
_dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0)
return _dummy_wgrads[(shape[0], shape[1], dtype)].detach()
ub_comm_cu_nums = int(os.getenv("NVTE_UB_COMM_CU_NUMS", "8"))
def initialize_ub(
shape: list,
tp_size: int,
......@@ -279,12 +279,24 @@ def initialize_ub(
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
# Default overlap methods for layers
if bool(int(os.getenv("NVTE_NO_PIPELINE_OVERLAP", "0"))):
if bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0"))) and bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0"))):
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop", "fc2_fprop"],
"pipeline": [],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
elif bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0"))):
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop"],
"pipeline": ["fc2_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
elif bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0"))):
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "fc2_fprop"],
"pipeline": ["proj_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
else:
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
......@@ -313,7 +325,7 @@ def initialize_ub(
default_cfg = {
"method": method,
"is_reduce_scatter": is_reduce_scatter,
"num_sm": 1 if method == "ring_exchange" else 8,
"num_sm": 1 if method == "ring_exchange" else ub_comm_cu_nums,
"cga_size": 1 if method == "ring_exchange" else 2,
"set_sm_margin": not method == "ring_exchange",
"num_splits": tp_size if method == "ring_exchange" else 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment