Commit 75e9ef24 authored by yuguo's avatar yuguo
Browse files
parents 5753c5bb 291fcf52
...@@ -45,6 +45,21 @@ bool ubuf_built_with_mpi() { ...@@ -45,6 +45,21 @@ bool ubuf_built_with_mpi() {
#endif #endif
} }
// Read an integer configuration value from environment variable `name`.
//
// Returns `defval` when the variable is unset, empty, or not a number at all
// (atoi would have silently mapped garbage to 0, which then got clamped to
// `minval` and masked the misconfiguration). A successfully parsed value is
// clamped to at least `minval`; out-of-int-range values saturate instead of
// invoking the undefined behavior atoi has on overflow.
static inline int getIntEnv(const char *name, int defval, int minval)
{
    const char* env = std::getenv(name);
    if (env == nullptr || env[0] == '\0')
    {
        return defval;
    }
    char *end = nullptr;
    long parsed = std::strtol(env, &end, 10);
    if (end == env)
    {
        // No digits consumed: value is not numeric, fall back to the default.
        return defval;
    }
    if (parsed > INT_MAX)
    {
        parsed = INT_MAX;
    }
    int val = static_cast<int>(parsed);
    if (val < minval)
    {
        val = minval;
    }
    return val;
}
CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle, int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
...@@ -74,10 +89,41 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl ...@@ -74,10 +89,41 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
_gemm_priority = gemm_priority; _gemm_priority = gemm_priority;
_comm_priority = comm_priority; _comm_priority = comm_priority;
} }
int comm_cu_nums = getIntEnv("NVTE_UB_COMM_CU_NUMS", 8, 4);
unsigned int cuMask[4];
unsigned int cuMaskSize = 4;
if (comm_cu_nums == 4) {
cuMask[0] = 0xfffffff0;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 8) {
cuMask[0] = 0xffffff00;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 16) {
cuMask[0] = 0xffff0000;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 32) {
cuMask[0] = 0x00000000;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else {
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
}
static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS]; static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
if (compute_streams[i] == nullptr) { if (compute_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority)); NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams[i], cuMaskSize, cuMask));
#endif
} }
_stream_compute.push_back(compute_streams[i]); _stream_compute.push_back(compute_streams[i]);
} }
...@@ -268,6 +314,10 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType ...@@ -268,6 +314,10 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) { atomic_gemm) {
const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
if(NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1'){
_ub_force_blas_multistream = true;
}
_ub_stream_nums = num_max_streams; _ub_stream_nums = num_max_streams;
_rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0); _rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
...@@ -282,10 +332,41 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType ...@@ -282,10 +332,41 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
int comm_cu_nums = getIntEnv("NVTE_UB_COMM_CU_NUMS", 8, 4);
unsigned int cuMask[4];
unsigned int cuMaskSize = 4;
if (comm_cu_nums == 4) {
cuMask[0] = 0x0000000f;
cuMask[1] = 0x00000000;
cuMask[2] = 0x00000000;
cuMask[3] = 0x00000000;
} else if (comm_cu_nums == 8) {
cuMask[0] = 0x000000ff;
cuMask[1] = 0x00000000;
cuMask[2] = 0x00000000;
cuMask[3] = 0x00000000;
} else if (comm_cu_nums == 16) {
cuMask[0] = 0x0000ffff;
cuMask[1] = 0x00000000;
cuMask[2] = 0x00000000;
cuMask[3] = 0x00000000;
} else if (comm_cu_nums == 32) {
cuMask[0] = 0xffffffff;
cuMask[1] = 0x00000000;
cuMask[2] = 0x00000000;
cuMask[3] = 0x00000000;
} else {
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
}
static cudaStream_t comm_stream; static cudaStream_t comm_stream;
if (comm_stream == nullptr) { if (comm_stream == nullptr) {
NVTE_CHECK_CUDA( NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority)); cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(
hipExtStreamCreateWithCUMask(&comm_stream, cuMaskSize, cuMask));
#endif
} }
_stream_comm = comm_stream; _stream_comm = comm_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
...@@ -499,7 +580,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -499,7 +580,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto bias_chunk = maybe_get_bias_chunk(0); auto bias_chunk = maybe_get_bias_chunk(0);
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
if (_ub_stream_nums == 1) { if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]); use_split_accumulator, _math_sms, _stream_compute[0]);
...@@ -516,7 +597,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -516,7 +597,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
workspace_chunk = get_tensor_chunk( workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) { if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms, accumulate, use_split_accumulator, _math_sms,
...@@ -572,7 +653,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -572,7 +653,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto workspace_chunk = get_tensor_chunk( auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) { if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms, accumulate, use_split_accumulator, _math_sms,
...@@ -631,7 +712,10 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, ...@@ -631,7 +712,10 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) { atomic_gemm) {
const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
if(NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1'){
_ub_force_blas_multistream = true;
}
_ub_stream_nums = num_max_streams; _ub_stream_nums = num_max_streams;
_is_p2p = true; _is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS; _is_reduce_scatter = comm_type == CommOverlapType::RS;
...@@ -902,7 +986,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -902,7 +986,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk( auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) { if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, use_split_accumulator, _math_sms,
...@@ -962,7 +1046,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -962,7 +1046,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk( auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) { if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, use_split_accumulator, _math_sms,
...@@ -1115,7 +1199,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, ...@@ -1115,7 +1199,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto workspace_chunk = auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) { if (_ub_stream_nums == 1 || _ub_force_blas_multistream == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id]); use_split_accumulator, _math_sms, _stream_compute[stream_id]);
......
...@@ -784,17 +784,17 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT ...@@ -784,17 +784,17 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for (int s = 0; s < num_stream_used; s++) { for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
} }
const char *NVTE_HIPBLAS_MULSTREAM = std::getenv("NVTE_FORCE_HIPBLAS_MULSTREAM"); const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM"); const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
bool NVTE_FORCE_HIPBLAS_MULSTREAM; bool NVTE_FORCE_BLAS_MULSTREAM;
if(NVTE_HIPBLAS_MULSTREAM != nullptr && NVTE_HIPBLAS_MULSTREAM[0] == '1'){ if(NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1'){
NVTE_FORCE_HIPBLAS_MULSTREAM = true; NVTE_FORCE_BLAS_MULSTREAM = true;
if((NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') && (NVTE_HIPBLAS_MULSTREAM != nullptr && NVTE_HIPBLAS_MULSTREAM[0] == '1')) if((NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') && (NVTE_BLAS_MULSTREAM != nullptr && NVTE_BLAS_MULSTREAM[0] == '1'))
NVTE_ERROR("NVTE_FORCE_HIPBLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time."); NVTE_ERROR("NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time.");
} else{ } else{
NVTE_FORCE_HIPBLAS_MULSTREAM = false; NVTE_FORCE_BLAS_MULSTREAM = false;
} }
if (NVTE_FORCE_HIPBLAS_MULSTREAM){ if (NVTE_FORCE_BLAS_MULSTREAM){
for (int i = 0; i < num_gemms; i++) { for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
...@@ -838,7 +838,7 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B ...@@ -838,7 +838,7 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm); NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
using namespace transformer_engine; using namespace transformer_engine;
int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);; int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);
// Inits streams and events (once, globally) // Inits streams and events (once, globally)
std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm); std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
......
...@@ -138,6 +138,7 @@ class CommOverlapCore { ...@@ -138,6 +138,7 @@ class CommOverlapCore {
class CommOverlapBase : public CommOverlapCore { class CommOverlapBase : public CommOverlapCore {
protected: protected:
int _ub_stream_nums; int _ub_stream_nums;
bool _ub_force_blas_multistream;
int _rs_kernel_type; int _rs_kernel_type;
bool _rs_overlap_first_gemm; bool _rs_overlap_first_gemm;
cudaStream_t _stream_comm; cudaStream_t _stream_comm;
...@@ -204,6 +205,7 @@ class CommOverlapBase : public CommOverlapCore { ...@@ -204,6 +205,7 @@ class CommOverlapBase : public CommOverlapCore {
class CommOverlapP2PBase : public CommOverlapCore { class CommOverlapP2PBase : public CommOverlapCore {
protected: protected:
int _ub_stream_nums; int _ub_stream_nums;
bool _ub_force_blas_multistream;
bool _is_reduce_scatter{false}; bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false}; bool _use_multiatomic_ag{false};
bool _aggregate; bool _aggregate;
......
...@@ -128,7 +128,7 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor ...@@ -128,7 +128,7 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor
_dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0)
return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() return _dummy_wgrads[(shape[0], shape[1], dtype)].detach()
ub_comm_cu_nums = int(os.getenv("NVTE_UB_COMM_CU_NUMS", "8"))
def initialize_ub( def initialize_ub(
shape: list, shape: list,
tp_size: int, tp_size: int,
...@@ -279,12 +279,24 @@ def initialize_ub( ...@@ -279,12 +279,24 @@ def initialize_ub(
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
# Default overlap methods for layers # Default overlap methods for layers
if bool(int(os.getenv("NVTE_NO_PIPELINE_OVERLAP", "0"))): if bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0"))) and bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0"))):
methods = { methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop", "fc2_fprop"], "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop", "fc2_fprop"],
"pipeline": [], "pipeline": [],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
} }
elif bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0"))):
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop"],
"pipeline": ["fc2_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
elif bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0"))):
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "fc2_fprop"],
"pipeline": ["proj_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
else: else:
methods = { methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
...@@ -313,7 +325,7 @@ def initialize_ub( ...@@ -313,7 +325,7 @@ def initialize_ub(
default_cfg = { default_cfg = {
"method": method, "method": method,
"is_reduce_scatter": is_reduce_scatter, "is_reduce_scatter": is_reduce_scatter,
"num_sm": 1 if method == "ring_exchange" else 8, "num_sm": 1 if method == "ring_exchange" else ub_comm_cu_nums,
"cga_size": 1 if method == "ring_exchange" else 2, "cga_size": 1 if method == "ring_exchange" else 2,
"set_sm_margin": not method == "ring_exchange", "set_sm_margin": not method == "ring_exchange",
"num_splits": tp_size if method == "ring_exchange" else 4, "num_splits": tp_size if method == "ring_exchange" else 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment