Unverified commit f458fcf4 authored by Sangkug Lym, committed by GitHub
Browse files

Add the option to use SM for P2P comm in TP overlap (#914)



* Add the option to use SM for P2P comm in TP overlap
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Python formatting with black
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Format C++ with clang-format
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/csrc/comm_gemm_overlap.h
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e2caf78d
......@@ -129,9 +129,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm;
int comm_sms;
int cga_size;
int use_ce;
int _num_comm_sm;
int _cga_size;
int _use_ce;
bool _atomic_gemm;
UbufCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size,
......@@ -151,9 +151,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
comm_created = true;
}
use_ce = 0;
comm_sms = num_comm_sm;
cga_size = comm_cga_size;
_use_ce = 0;
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
_empty_tensor = empty_tensor;
// Allocate and register extra userbuffers
......@@ -211,9 +211,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type,
at::Tensor rs_output) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = comm_sms;
_ub_comm->cga_size = cga_size;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get the current userbuf offset
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
......@@ -283,9 +283,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = comm_sms;
_ub_comm->cga_size = cga_size;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions
int m = A.size(0);
int k = A.size(1);
......@@ -389,9 +389,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output) {
// Get GEMM dimensions
_ub_comm->use_ce = use_ce;
_ub_comm->sms = comm_sms;
_ub_comm->cga_size = cga_size;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
int m = A.size(0);
int k = A.size(1);
int n = B.size(0);
......@@ -605,14 +605,14 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv;
int use_ce;
int sms;
int cga_size;
int _use_ce;
int _num_comm_sm;
int _cga_size;
bool _atomic_gemm;
UbufP2PCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size,
int num_comm_sm, int comm_cga_size, bool set_sm_margin, bool aggregate2,
int num_max_streams, bool is_reduce_scatter, bool atomic_gemm,
int num_max_streams, bool is_reduce_scatter, bool atomic_gemm, bool use_ce,
torch::Tensor empty_tensor) {
// Initialize userbuf communicator
if (!comm_created) {
......@@ -628,9 +628,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
}
comm_created = true;
}
use_ce = 1;
sms = 1;
cga_size = 1;
_use_ce = use_ce;
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
_empty_tensor = empty_tensor;
// Create workspace tensor with userbuffer
......@@ -726,9 +726,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int n = _ubuf.size(0);
......@@ -834,9 +834,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0);
......@@ -1002,9 +1002,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor rs_output) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
......@@ -1077,9 +1077,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
int k = A.size(1);
int n = B.size(0);
......
......@@ -191,7 +191,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool,
.def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, bool,
torch::Tensor>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs)
......
......@@ -124,6 +124,7 @@ def initialize_ub(
num_splits: int = 0,
aggregate: int = 0,
atomic_gemm: int = 0,
use_ce: bool = True,
fp8_buf: bool = False,
) -> None:
if atomic_gemm:
......@@ -177,6 +178,7 @@ def initialize_ub(
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
is_reduce_scatter, # overlap with reduce scatter
atomic_gemm, # use a single GEMM with atomic-counters
use_ce, # use copy engine for P2P communications
torch.Tensor(), # empty tensor to pass to counters
)
else:
......@@ -224,12 +226,13 @@ def initialize_ub(
if ub_cfgs is not None and name in ub_cfgs:
ub_cfg = ub_cfgs[name]
method = ub_cfg.get("method", get_method(name))
num_sm = ub_cfg.get("num_sm", 16)
cga_size = ub_cfg.get("cga_size", 2)
num_sm = ub_cfg.get("num_sm", 1 if method == "ring_exchange" else 16)
cga_size = ub_cfg.get("cga_size", 1 if method == "ring_exchange" else 2)
num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
set_sm_margin = ub_cfg.get("set_sm_margin", 0)
aggregate = ub_cfg.get("aggregate", 0)
atomic_gemm = ub_cfg.get("atomic_gemm", 0)
use_ce = ub_cfg.get("use_ce", True)
is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0
# Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter
fp8_buf = (name in layers_all_gather_overlap) or (
......@@ -245,6 +248,7 @@ def initialize_ub(
num_splits,
aggregate,
atomic_gemm,
use_ce,
fp8_buf,
)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment