"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "fbb16f4a71393fe7188f9f198438be6f5265e29b"
Unverified Commit f458fcf4 authored by Sangkug Lym's avatar Sangkug Lym Committed by GitHub
Browse files

Add the option to use SM for P2P comm in TP overlap (#914)



* Add the option to use SM for P2P comm in TP overlap
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* Python formatting with black
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Format C++ with clang-format
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/csrc/comm_gemm_overlap.h
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e2caf78d
...@@ -129,9 +129,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -129,9 +129,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute; std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm; cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm;
int comm_sms; int _num_comm_sm;
int cga_size; int _cga_size;
int use_ce; int _use_ce;
bool _atomic_gemm; bool _atomic_gemm;
UbufCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size, UbufCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size,
...@@ -151,9 +151,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -151,9 +151,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
} }
comm_created = true; comm_created = true;
} }
use_ce = 0; _use_ce = 0;
comm_sms = num_comm_sm; _num_comm_sm = num_comm_sm;
cga_size = comm_cga_size; _cga_size = comm_cga_size;
_empty_tensor = empty_tensor; _empty_tensor = empty_tensor;
// Allocate and register extra userbuffers // Allocate and register extra userbuffers
...@@ -211,9 +211,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -211,9 +211,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type,
at::Tensor rs_output) { at::Tensor rs_output) {
_ub_comm->use_ce = use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = comm_sms; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = cga_size; _ub_comm->cga_size = _cga_size;
// Get the current userbuf offset // Get the current userbuf offset
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr()); char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
...@@ -283,9 +283,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -283,9 +283,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap, bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) { at::Tensor rs_output) {
_ub_comm->use_ce = use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = comm_sms; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = cga_size; _ub_comm->cga_size = _cga_size;
// Get GEMM dimensions // Get GEMM dimensions
int m = A.size(0); int m = A.size(0);
int k = A.size(1); int k = A.size(1);
...@@ -389,9 +389,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -389,9 +389,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
size_t workspaceSize, bool accumulate, bool use_split_accumulator, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output) { bool gemm_overlap, at::Tensor rs_output) {
// Get GEMM dimensions // Get GEMM dimensions
_ub_comm->use_ce = use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = comm_sms; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = cga_size; _ub_comm->cga_size = _cga_size;
int m = A.size(0); int m = A.size(0);
int k = A.size(1); int k = A.size(1);
int n = B.size(0); int n = B.size(0);
...@@ -605,14 +605,14 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -605,14 +605,14 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute; std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv;
int use_ce; int _use_ce;
int sms; int _num_comm_sm;
int cga_size; int _cga_size;
bool _atomic_gemm; bool _atomic_gemm;
UbufP2PCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size, UbufP2PCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size,
int num_comm_sm, int comm_cga_size, bool set_sm_margin, bool aggregate2, int num_comm_sm, int comm_cga_size, bool set_sm_margin, bool aggregate2,
int num_max_streams, bool is_reduce_scatter, bool atomic_gemm, int num_max_streams, bool is_reduce_scatter, bool atomic_gemm, bool use_ce,
torch::Tensor empty_tensor) { torch::Tensor empty_tensor) {
// Initialize userbuf communicator // Initialize userbuf communicator
if (!comm_created) { if (!comm_created) {
...@@ -628,9 +628,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -628,9 +628,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
} }
comm_created = true; comm_created = true;
} }
use_ce = 1; _use_ce = use_ce;
sms = 1; _num_comm_sm = num_comm_sm;
cga_size = 1; _cga_size = comm_cga_size;
_empty_tensor = empty_tensor; _empty_tensor = empty_tensor;
// Create workspace tensor with userbuffer // Create workspace tensor with userbuffer
...@@ -726,9 +726,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -726,9 +726,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
_ub_comm->use_ce = use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = sms; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = cga_size; _ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts // Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1); const int m = (transa) ? A.size(0) : A.size(1);
const int n = _ubuf.size(0); const int n = _ubuf.size(0);
...@@ -834,9 +834,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -834,9 +834,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
_ub_comm->use_ce = use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = sms; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = cga_size; _ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts // Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1); const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0); const int k = (transa) ? A.size(1) : A.size(0);
...@@ -1002,9 +1002,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1002,9 +1002,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { bool accumulate, bool use_split_accumulator, at::Tensor rs_output) {
_ub_comm->use_ce = use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = sms; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = cga_size; _ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes // Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
...@@ -1077,9 +1077,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1077,9 +1077,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output) { at::Tensor rs_output) {
_ub_comm->use_ce = use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = sms; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = cga_size; _ub_comm->cga_size = _cga_size;
int k = A.size(1); int k = A.size(1);
int n = B.size(0); int n = B.size(0);
......
...@@ -191,7 +191,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -191,7 +191,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap); .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap") py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, .def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, bool,
torch::Tensor>()) torch::Tensor>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag) .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs) .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs)
......
...@@ -124,6 +124,7 @@ def initialize_ub( ...@@ -124,6 +124,7 @@ def initialize_ub(
num_splits: int = 0, num_splits: int = 0,
aggregate: int = 0, aggregate: int = 0,
atomic_gemm: int = 0, atomic_gemm: int = 0,
use_ce: bool = True,
fp8_buf: bool = False, fp8_buf: bool = False,
) -> None: ) -> None:
if atomic_gemm: if atomic_gemm:
...@@ -177,6 +178,7 @@ def initialize_ub( ...@@ -177,6 +178,7 @@ def initialize_ub(
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
is_reduce_scatter, # overlap with reduce scatter is_reduce_scatter, # overlap with reduce scatter
atomic_gemm, # use a single GEMM with atomic-counters atomic_gemm, # use a single GEMM with atomic-counters
use_ce, # use copy engine for P2P communications
torch.Tensor(), # empty tensor to pass to counters torch.Tensor(), # empty tensor to pass to counters
) )
else: else:
...@@ -224,12 +226,13 @@ def initialize_ub( ...@@ -224,12 +226,13 @@ def initialize_ub(
if ub_cfgs is not None and name in ub_cfgs: if ub_cfgs is not None and name in ub_cfgs:
ub_cfg = ub_cfgs[name] ub_cfg = ub_cfgs[name]
method = ub_cfg.get("method", get_method(name)) method = ub_cfg.get("method", get_method(name))
num_sm = ub_cfg.get("num_sm", 16) num_sm = ub_cfg.get("num_sm", 1 if method == "ring_exchange" else 16)
cga_size = ub_cfg.get("cga_size", 2) cga_size = ub_cfg.get("cga_size", 1 if method == "ring_exchange" else 2)
num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0) num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
set_sm_margin = ub_cfg.get("set_sm_margin", 0) set_sm_margin = ub_cfg.get("set_sm_margin", 0)
aggregate = ub_cfg.get("aggregate", 0) aggregate = ub_cfg.get("aggregate", 0)
atomic_gemm = ub_cfg.get("atomic_gemm", 0) atomic_gemm = ub_cfg.get("atomic_gemm", 0)
use_ce = ub_cfg.get("use_ce", True)
is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0 is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0
# Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter # Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter
fp8_buf = (name in layers_all_gather_overlap) or ( fp8_buf = (name in layers_all_gather_overlap) or (
...@@ -245,6 +248,7 @@ def initialize_ub( ...@@ -245,6 +248,7 @@ def initialize_ub(
num_splits, num_splits,
aggregate, aggregate,
atomic_gemm, atomic_gemm,
use_ce,
fp8_buf, fp8_buf,
) )
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment