Unverified commit 933294dc authored by Alp Dener, committed by GitHub

[C/PyTorch] Userbuffers and comm+GEMM overlap algorithms refactored and moved to TE/common (#1067)



* moved userbuffers code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved comm+GEMM overlap code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed PyTorch dependency from comm+GEMM overlap in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated TE/PyTorch Python API to match the refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated unit tests to work with refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* added a pylint exception to comm+GEMM overlap test runner
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixing linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* added documentation for te.initialize_ub (usage sketch after the commit message)
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed compile errors when building with NVTE_UB_WITH_MPI=1
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed default bootstrap backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* switched default bootstrap backend priority to MPI > Gloo > NCCL
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated bootstrap backend documentation
Signed-off-by: Alp Dener <adener@nvidia.com>

* close UB bootstrap socket to avoid interfering with CUDA Multicast shareable file handle send/recv
Signed-off-by: Alp Dener <adener@nvidia.com>

* added torch::Tensor wrappers for communication buffer and atomic counters so PyTorch can factor externally allocated memory into its garbage collection threshold
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* automated handling of world, local and node ranks/sizes within C++ CommOverlapHelper to simplify Python function signatures
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed incorrect read of environment variables
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected priority for _SOCKET_IFNAME environment variables in UB bootstrapping
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved multicast support check to cuda_runtime.h and replaced cudaDeviceGetProp call with cached sm_count()
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* removed commented out old code and replaced external collective function type defines with aliases
Signed-off-by: Alp Dener <adener@nvidia.com>

* compile-time CUDA version guard for CUDA Driver Multicast attribute
Signed-off-by: Alp Dener <adener@nvidia.com>

* added compile-time CUDA version guards to Multicast code in Userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* condensed UB docs, corrected const violations
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed autodoc rst for UB calls, added CUDA version guard on Multicast UB kernels
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect UB type reporting for P2P overlaps, comment reformatting
Signed-off-by: Alp Dener <adener@nvidia.com>

* add docstring to tex.ubuf_built_with_mpi()
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 35bbe740
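Below is a minimal Python usage sketch of the setup this PR documents (te.initialize_ub with the new bootstrap backend selection). Apart from bootstrap_backend, the argument names and values are assumptions for illustration, not a verbatim API listing; when TE is built with NVTE_UB_WITH_MPI=1, bootstrapping goes through MPI instead of a torch.distributed backend.

# Usage sketch (hypothetical argument names except bootstrap_backend).
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")  # world group used for UB bootstrapping
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

te.initialize_ub(
    [4096, 1024],              # sample communication-buffer shape (assumed)
    tp_size=8,                 # tensor-parallel group size (assumed)
    use_fp8=True,              # (assumed)
    bootstrap_backend="gloo",  # optional; default priority is MPI > Gloo > NCCL
)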
@@ -45,8 +45,8 @@ def fp8_gemm(
     use_bias: bool = False,
     use_split_accumulator: bool = False,
     D_dtype: Optional[tex.DType] = None,
-    ub_algo: tex.UbufOverlapAlgo = None,
-    ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
+    ub_algo: tex.CommOverlapAlgo = None,
+    ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None,
     extra_output_tensor: torch.Tensor = None,
 ) -> torch.Tensor:
     """TN layout GEMM with fp8 inputs."""
@@ -107,7 +107,7 @@ def fp8_gemm(
     fn = torch.ops.tex_ts.te_gemm_ts
     if ub_algo is not None:
         assert ub is not None, "ub object is None!"
-        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
+        if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG:
             fn = ub.bulk_overlap
             extra_output_tensor = (
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
@@ -115,11 +115,11 @@ def fp8_gemm(
             args = tuple(
                 args
                 + (
-                    1,
+                    tex.CommOverlapType.AG,
                     extra_output_tensor,
                 )
             )
-        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
+        elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS:
             fn = ub.bulk_overlap
             extra_output_tensor = (
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
@@ -127,23 +127,23 @@ def fp8_gemm(
             args = tuple(
                 args
                 + (
-                    0,
+                    tex.CommOverlapType.RS,
                     extra_output_tensor,
                 )
             )
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
+        elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
             fn = ub.split_overlap_ag_p2p
             extra_output_tensor = (
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
             )
             args = tuple(args + (extra_output_tensor,))
-        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
+        elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P:
             fn = ub.atomic_gemm_overlap_ag_p2p
             extra_output_tensor = (
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
             )
             args = tuple(args + (extra_output_tensor,))
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
+        elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS:
             fn = ub.split_overlap_rs
             assert (
                 extra_output_tensor is not None
@@ -155,13 +155,13 @@ def fp8_gemm(
                     extra_output_tensor,
                 )
             )
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
+        elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
             fn = ub.split_overlap_rs_p2p
             assert (
                 extra_output_tensor is not None
             ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor"
             args = tuple(args + (extra_output_tensor,))
-        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
+        elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS:
             fn = ub.atomic_gemm_overlap_rs
             assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor"
             args = tuple(
@@ -171,16 +171,13 @@ def fp8_gemm(
                     extra_output_tensor,
                 )
             )
-        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P:
+        elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P:
             fn = ub.atomic_gemm_overlap_rs_p2p
             assert (
                 extra_output_tensor is not None
             ), "ATOMIC_GEMM_RS_P2P requires extra output tensor"
             args = tuple(args + (extra_output_tensor,))
-    if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
-        out = fn(*args)
-    else:
-        _ = fn(*args)
+    _ = fn(*args)
     return out, gelu_input
@@ -198,8 +195,8 @@ def gemm(
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     use_bias: bool = False,
-    ub_algo: tex.UbufOverlapAlgo = None,
-    ub: tex.UbufCommOverlap = None,
+    ub_algo: tex.CommOverlapAlgo = None,
+    ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None,
     extra_output_tensor: torch.Tensor = None,
 ) -> Tuple[Union[torch.Tensor, None], ...]:
     """Non FP8 GEMM."""
@@ -270,19 +267,19 @@ def gemm(
     fn = torch.ops.tex_ts.te_gemm_ts
     if ub_algo is not None:
         assert ub is not None, "ub object is None!"
-        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
+        if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG:
             fn = ub.bulk_overlap
-            args = tuple(args + (1, empty_tensor))
-        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
+            args = tuple(args + (tex.CommOverlapType.AG, empty_tensor))
+        elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS:
             fn = ub.bulk_overlap
-            args = tuple(args + (0, empty_tensor))
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
+            args = tuple(args + (tex.CommOverlapType.RS, empty_tensor))
+        elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
             fn = ub.split_overlap_ag_p2p
             extra_output_tensor = (
                 empty_tensor if extra_output_tensor is None else extra_output_tensor
             )
             args = tuple(args + (extra_output_tensor,))
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
+        elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS:
             fn = ub.split_overlap_rs
             assert (
                 extra_output_tensor is not None
@@ -294,7 +291,7 @@ def gemm(
                     extra_output_tensor,
                 )
             )
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
+        elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
             fn = ub.split_overlap_rs_p2p
             assert (
                 extra_output_tensor is not None
...
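For downstream callers the rename is mechanical. A short sketch of the old-to-new mapping follows; the call site is hypothetical, with the `ub` construction and the remaining fp8_gemm arguments elided:

# Renames introduced by this PR, as seen in the hunks above:
#   tex.UbufOverlapAlgo    -> tex.CommOverlapAlgo
#   tex.UbufCommOverlap    -> tex.CommOverlap
#   tex.UbufP2PCommOverlap -> tex.CommOverlapP2P
#   bulk-overlap direction ints 1 / 0 -> tex.CommOverlapType.AG / tex.CommOverlapType.RS
# Hypothetical call-site update for a user of fp8_gemm():
out, gelu_input = fp8_gemm(
    ...,  # GEMM operands and FP8 metadata, unchanged by this PR
    ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_RS,  # was tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
    ub=ub,  # now a tex.CommOverlap or tex.CommOverlapP2P object
    extra_output_tensor=rs_out,  # still required for SPLIT_PIPELINED_RS
)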
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <stdio.h>
#include <stdlib.h>
#include <torch/cuda.h>
#include <torch/custom_class.h>
#include <torch/extension.h>
#include <torch/types.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "extensions.h"
#include "userbuffers/userbuffers.h"
#define HALF_BYTES 2
#define UB_MAX_SM 32
using namespace torch::indexing;
using namespace std::placeholders;
namespace ubuf {
bool device_supports_multicast() {
int dev, supports_multicast;
CUdevice cudev;
NVTE_CHECK_CUDA(cudaGetDevice(&dev));
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, dev);
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &supports_multicast,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
return static_cast<bool>(supports_multicast);
}
bool ubuf_built_with_mpi() {
#ifdef NVTE_UB_WITH_MPI
return true;
#else
return false;
#endif
}
class UbufBootstrapCallbacks : torch::CustomClassHolder {
private:
bool initialized{false};
bool backend_is_nccl{false};
std::map<std::string, c10d::ProcessGroup *> pgs;
public:
UbufBootstrapCallbacks() {
#ifndef NVTE_UB_WITH_MPI
NVTE_ERROR("Internal TE error: Dummy UbufBootstrapCallbacks init without NVTE_UB_WITH_MPI=1!");
#endif
} // empty constructor for NVTE_UB_WITH_MPI=1
UbufBootstrapCallbacks(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group) {
pgs.insert({"world", world_group});
c10d::ProcessGroup::BackendType backend = world_group->getBackendType();
backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL);
NVTE_CHECK(intra_node_group->getBackendType() == backend,
"Internal TE error: Intra-node group must be on the same backend (%s) as the world ",
"group!", world_group->getBackendName());
pgs.insert({"intra", intra_node_group});
initialized = true;
}
~UbufBootstrapCallbacks() {
for (auto &pg : pgs) pg.second = nullptr;
backend_is_nccl = false;
initialized = false;
}
void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
char *group) {
NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ",
"with valid process groups!");
auto localtensor =
torch::from_blob(localdata, {static_cast<int64_t>(localbytes / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor;
auto globaltensor =
torch::from_blob(globaldata, {static_cast<int64_t>(globalbytes / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor;
std::vector<std::vector<torch::Tensor>> globalchunks = {globaltmp.chunk(pgs[group]->getSize())};
std::vector<torch::Tensor> localchunk = {localtmp};
auto work = pgs[group]->allgather(globalchunks, localchunk);
work->wait();
if (backend_is_nccl) {
globaltensor.copy_(globaltmp.cpu());
globaltmp = torch::Tensor();
localtmp = torch::Tensor();
}
}
void ub_barrier(char *group) {
NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ",
"with valid process groups!");
auto work = pgs[group]->barrier();
work->wait();
}
};
enum class COMM_TYPE { RS = 0, AG = 1 };
enum class UBOverlapAlgo {
BULK_OVERLAP_AG = 0,
BULK_OVERLAP_RS = 1,
SPLIT_PIPELINED_AG_P2P = 2,
SPLIT_PIPELINED_RS = 3,
SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7
};
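// NOTE: mirrored in Python as tex.UbufOverlapAlgo; this PR renames the binding
// to tex.CommOverlapAlgo (see the Python diff above).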
struct UbufBase {
static inline communicator *_ub_comm{nullptr};
static inline bool comm_created{false};
};
struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int _tp_id;
int _tp_size;
int _num_splits;
int _math_sms;
int _ub_reg;
void *_ubuf_ptr;
torch::Tensor _ubuf;
torch::Tensor output_tensor;
torch::Tensor _ubuf_scale_inv;
bool _ubuf_scale_inv_initialized;
torch::Tensor counter;
at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm;
int _num_comm_sm;
int _cga_size;
int _use_ce;
bool _atomic_gemm;
UbufCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size,
int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm,
UbufBootstrapCallbacks &callbacks) {
// Initialize userbuf communicator
if (!comm_created) {
if (myrank == 0) {
printf("!!! [UB] Create Userbuffers Communicator\n");
}
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
#else
create_communicator_grouped2(
&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5),
std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1);
#endif
comm_created = true;
}
_use_ce = 0;
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
// Allocate and register extra userbuffers
int ubuf_bytes = sample.numel() * sample.element_size();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
if (_ub_comm->myrank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg);
}
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
_stream_compute.push_back(
at::cuda::getStreamFromExternal(stream, stream_main.device_index()));
}
_num_splits = num_splits;
_tp_size = tp_size;
_tp_id = (_ub_comm->myrank % _tp_size);
_ubuf_scale_inv_initialized = false;
// Set the number of SMs for GEMM with margin
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
_math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
_math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);
output_tensor = torch::Tensor();
_atomic_gemm = atomic_gemm;
if (_atomic_gemm) {
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({num_splits * 2}, counter_options);
counter.index_put_({Slice(None, num_splits)}, 1);
}
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_d2dcopy, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);
}
~UbufCommOverlap() {
cudaEventDestroy(_stop_comm);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_start_d2dcopy);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
if (comm_created) {
#ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm);
#else
destroy_communicator(_ub_comm);
#endif
comm_created = false;
}
}
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
std::vector<at::Tensor> bulk_overlap(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type,
at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get the current userbuf offset
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses a 2-byte element size
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type == COMM_TYPE::RS) {
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
}
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication: AG and RS
if (_comm_type == COMM_TYPE::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm);
} else if (_comm_type == COMM_TYPE::RS) {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
comm_elements *= 2;
float *scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
assert(rs_output.numel() == _ubuf.numel() / _tp_size);
assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, scale_inv_ptr, _ub_reg, 0,
comm_elements, _ub_comm,
(cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm,
(cudaStream_t)_stream_comm);
}
} else {
NVTE_ERROR("Not supported communication type.");
}
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0);
te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale,
D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, workspaceSize,
accumulate, use_split_accumulator, _math_sms);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
// Generate output tensor from userbuf data pointer
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
_ub_comm->sms = ori_sms;
return {D, output_tensor};
} // bulk_overlap
/*
** Atomic FPROP GEMM + ReduceScatter
*/
void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions
int m = A.size(0);
int k = A.size(1);
int n = B.size(0);
int m_chunk = m / _num_splits;
int workspace_size_chunk = workspaceSize / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.data_ptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0);
torch::Tensor input_a = torch::from_blob(input_a_chunk_ptr, {m, k}, A.options());
torch::Tensor output_d = torch::from_blob(output_buf_chunk_ptr, {n, m}, _ubuf.options());
// torch::zeros({n, m}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_atomic_gemm(input_a, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_d, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/,
counter);
for (int i = 0; i < _num_splits; i++) {
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D_type, fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m,
_num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_num_splits, &counter_ptr[i], _ub_comm,
(cudaStream_t)_stream_comm);
}
} else if (env_p != nullptr && env_p[0] == '2') {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D_type, fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n,
m, _num_splits, counter_ptr, _ub_comm,
(cudaStream_t)_stream_comm);
}
break;
} else {
assert(_ubuf.element_size() != 1);
consumer(counter_ptr, i, (cudaStream_t)_stream_comm);
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
_ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return;
} // atomic_gemm_overlap_rs
/*
** Split FPROP GEMM + ReduceScatter
*/
void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output) {
// Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
int m = A.size(0);
int k = A.size(1);
int n = B.size(0);
int m_chunk = m / _num_splits;
int input_a_chunk_size = m_chunk * k;
int output_chunk_size = n * m_chunk;
int workspace_size_chunk = workspaceSize / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.data_ptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0);
if (gemm_overlap) {
torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
NVTE_CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
} else {
for (int i = 0; i < _num_splits; i++) {
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm,
(cudaStream_t)_stream_compute[i % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
_ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return;
} // split_overlap_rs
void set_ubuf_scale_inv(const torch::Tensor &scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); }
/*
** Helper function to copy input to _ubuf
*/
void copy_input_to_ubuf(torch::Tensor input, int comm_type) {
char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type == COMM_TYPE::AG) {
if ((input.numel() * _tp_size) != _ubuf.numel() ||
input.element_size() != _ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
} else {
if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
}
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)_stream_comm));
}
torch::Tensor &get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
return output_tensor;
}
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return false; }
}; // UbufCommOverlap
struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int _tp_id;
int _tp_size;
int _ub_reg, _ub_reg2;
int _next_rank, _prev_rank, _rank, _rank_round_tp;
int _aggregate2;
int _math_sms;
int _self_chunk_id;
void *_ubuf_ptr;
torch::Tensor _ubuf;
torch::Tensor counter;
torch::Tensor _ubuf_scale_inv;
bool _ubuf_scale_inv_initialized;
std::vector<torch::Tensor> _ubufs;
at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true);
at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv;
int _use_ce;
int _num_comm_sm;
int _cga_size;
bool _atomic_gemm;
UbufP2PCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size,
bool set_sm_margin, bool aggregate2, int num_max_streams,
bool is_reduce_scatter, bool atomic_gemm, bool use_ce,
UbufBootstrapCallbacks &callbacks) {
// Initialize userbuf communicator
if (!comm_created) {
if (myrank == 0) {
printf("!!! [UB] Create Userbuffers Communicator\n");
}
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
#else
create_communicator_grouped2(
&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5),
std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1);
#endif
comm_created = true;
}
_use_ce = use_ce;
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
// Create workspace tensor with userbuffer
int ubuf_bytes = sample.numel() * sample.element_size();
int ubuf_chunk_bytes = ubuf_bytes / tp_size;
int num_ubuf_chunks = tp_size;
if (is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
// outputs for reduction at the end of the pipelining.
ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1));
num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1);
}
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf = torch::from_blob(
_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options());
if (_ub_comm->myrank == 0) {
printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
}
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
for (int i = 0; i < num_ubuf_chunks; i++) {
auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)},
sample.options());
_ubufs.push_back(std::move(ubuf_chunk));
ubuf_byte_ptr += ubuf_chunk_bytes;
}
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
for (int i = 0; i < std::min(num_max_streams, tp_size); i++) {
cudaStream_t stream;
cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
_stream_compute.push_back(
at::cuda::getStreamFromExternal(stream, stream_main.device_index()));
}
// Set the number of SMs for GEMM with margin
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
_math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
_math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);
_tp_size = tp_size;
_aggregate2 = aggregate2;
_rank = _ub_comm->myrank;
_tp_id = (_rank % _tp_size);
_rank_round_tp = (_rank / _tp_size) * _tp_size;
_next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp;
_prev_rank = (_tp_size + _rank - 1) % _tp_size + _rank_round_tp;
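// Ring-exchange neighbors wrap around within the tensor-parallel group;
// _rank_round_tp is the first global rank of this TP group.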
_ubuf_scale_inv_initialized = false;
_atomic_gemm = atomic_gemm;
_self_chunk_id = _tp_id;
if (_atomic_gemm) {
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({_tp_size * 2}, counter_options);
counter.index_put_({Slice(None, _tp_size)}, 1);
if (!is_reduce_scatter) {
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (_rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') {
_use_ce = 0;
_ub_comm->push = 1;
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
}
}
_self_chunk_id = 0;
counter.index_put_({_self_chunk_id}, 0);
}
}
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_send, 0);
cudaEventCreateWithFlags(&_stop_recv, 0);
}
~UbufP2PCommOverlap() {
cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
if (comm_created) {
#ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm);
#else
destroy_communicator(_ub_comm);
#endif
comm_created = false;
}
}
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
** needed so that the AG outputs on each rank end up in contiguous memory after
** all ring exchange phases.
*/
torch::Tensor atomic_gemm_overlap_ag(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int n = _ubuf.size(0);
const int n_chunk = n / _tp_size;
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Create a GEMM output buffer with N+1 chunks in contiguous memory
torch::Tensor D_buffer = torch::empty({n_chunk * (_tp_size + 1), m}, D.options());
D = torch::from_blob(D_buffer.data_ptr(), {D.size(0), D.size(1)}, D.options());
// Get output and workspace data pointers
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0);
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. The buffer being sent is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubuf[rank] so that the
// AG output on all ranks is contiguous after the ring exchanges.
int send_chunk_id = i;
int recv_chunk_id = i + 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
if (i == 0) {
_ub_comm->use_ce = 0;
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
true, (cudaStream_t)_stream_recv);
}
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_next_rank, (cudaStream_t)_stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, (cudaStream_t)_stream_recv);
producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv);
}
if (i == 0) {
te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb,
D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, false, counter);
}
}
// Store the input activation for backprop
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
// Reset atomic counters
consumer_batch(counter_ptr, 1, _tp_size, (cudaStream_t)stream_main);
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.data_ptr());
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr,
n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
// Return the last N rows of D_buffer
_ub_comm->sms = ori_sms;
torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n);
return D_return;
} // atomic_gemm_overlap_ag
/*
** Split AllGather + GEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
** needed so that the AG outputs on each rank end up in contiguous memory after
** all ring exchange phases.
*/
torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0);
const int n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
const int output_chunk_bytes = (n_chunk * m) * D.element_size();
const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.data_ptr());
char *pre_gelu_out_ptr = reinterpret_cast<char *>(pre_gelu_out.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
if (_aggregate2) {
const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp;
// Ring exchange of 2X inputs chunks
for (int i = 0; i < num_steps; i++) {
send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size;
recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size;
send_offset = comm_bytes * send_chunk_id;
recv_offset = comm_bytes * recv_chunk_id;
// GEMM
torch::Tensor input_b_chunk =
torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options());
torch::Tensor output_chunk = torch::from_blob(
output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options());
if (do_gelu) {
pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes),
{n_chunk * 2, m}, pre_gelu_out.options());
}
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
if (i < num_steps - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
}
}
} else {
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. The buffer being sent is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubuf[rank] so that the
// AG output on all ranks is contiguous after the ring exchanges.
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
torch::Tensor output_chunk = torch::from_blob(
output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options());
if (do_gelu) {
pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes),
{n_chunk, m}, pre_gelu_out.options());
}
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(A, A_scale_inverse, A_type, transa, _ubufs[send_chunk_id], B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
if (i < _tp_size - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, (cudaStream_t)_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
}
}
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
at::cuda::setCurrentCUDAStream(stream_main);
_ub_comm->sms = ori_sms;
return D;
} // split_overlap_ag
/*
** Split ReduceScatter + AtomicGEMM using P2P communication
*/
void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Get input and workspace data pointers
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
// Catch up the main stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, _ubuf,
D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk,
workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, 0, _tp_size,
true, counter);
// P2P communication chunk
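  // Each iteration waits until the atomic GEMM flags output chunk (i - 1) as produced, then
  // forwards it around the ring so every rank collects the partial outputs reduced below.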
for (int i = 1; i < _tp_size; i++) {
int send_chunk_id = i - 1;
int recv_chunk_id = send_chunk_id + _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
(cudaStream_t)_stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
(cudaStream_t)_stream_recv);
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
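  // The reduction view starts at _ubufs[_tp_size - 1] and spans _tp_size contiguous chunks:
  // this rank's own output chunk followed by the chunks received from its peers.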
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main););
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
_ub_comm->sms = ori_sms;
}
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
int k = A.size(1);
int n = B.size(0);
// Get communication and GEMM input chunk sizes
int n_chunk = n / _tp_size;
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int input_b_chunk_bytes = n_chunk * k * B.element_size();
// Get input and workspace data pointers
char *input_b_ptr = reinterpret_cast<char *>(B.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
// Catch up the main stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
// GEMM and send/recv chunks
for (int i = 0; i < _tp_size; i++) {
// GEMM chunk
int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes);
torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options());
    // Store the last GEMM chunk output to the receive buffer.
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb,
_ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms);
if (i > 0) {
// P2P communication chunk
int send_offset = comm_bytes * (i - 1);
int recv_offset = comm_bytes * (i - 1 + _tp_size);
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
NVTE_CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
send_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
recv_rank, (cudaStream_t)_stream_recv);
}
}
at::cuda::setCurrentCUDAStream(stream_main);
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main););
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
_ub_comm->sms = ori_sms;
}
/*
** Copy input to _ubufs[0]
*/
void copy_input_to_ubuf(torch::Tensor input, bool chunk) {
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
if (chunk) {
// Copy input to the target ubuf chunk by rank offset
if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
} else {
if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
}
}
torch::Tensor get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
}
void set_ubuf_scale_inv(const torch::Tensor &scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); }
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return true; }
}; // UbufP2PCommOverlap
} // namespace ubuf
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_
@@ -24,6 +24,7 @@
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
@@ -37,12 +38,14 @@
#include <transformer_engine/transpose.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <cassert>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <vector>
#include "common/util/logging.h"
...
@@ -7,6 +7,8 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#include <optional>
#include "common.h"
#include "common/common.h"
@@ -504,4 +506,184 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> input_row_list,
                             std::vector<size_t> padded_input_row_list);
/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
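// Wraps torch.distributed process groups and exposes the allgather/barrier callbacks that the
// TE/common Userbuffers bootstrapping code expects.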
class CommOverlapHelper : torch::CustomClassHolder {
private:
bool initialized{false};
bool backend_is_nccl{false};
std::map<std::string, c10d::ProcessGroup *> pgs;
public:
int myrank = -1;
int numranks = -1;
int mylocal = -1;
int numlocal = -1;
int mynode = -1;
int numnodes = -1;
CommOverlapHelper();
CommOverlapHelper(c10d::ProcessGroup *world_group,
std::optional<c10d::ProcessGroup *> intra_node_group,
std::optional<c10d::ProcessGroup *> inter_node_group);
~CommOverlapHelper();
void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
ExtComm comm);
void ub_barrier(ExtComm comm);
};
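// PyTorch wrapper around the collective (pipelined) comm+GEMM overlap algorithms in TE/common.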
class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
private:
torch::Tensor _ubuf_torch;
torch::Tensor _ubuf_counter;
public:
CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
CommOverlapHelper *helper, int tp_size, int num_splits = 3,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);
void set_ubuf_scale_inv(torch::Tensor scale_inv) {
assert(scale_inv.numel());
assert(scale_inv.scalar_type() == torch::kFloat32);
transformer_engine::CommOverlapBase::set_ubuf_scale_inv(
reinterpret_cast<float *>(scale_inv.data_ptr()));
}
void copy_input_to_ubuf(torch::Tensor input, int comm_type);
torch::Tensor get_ubuf_output(int comm_type);
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
std::vector<at::Tensor> bulk_overlap(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
transformer_engine::CommOverlapType comm_type, at::Tensor rs_output);
/*
** Split FPROP GEMM + ReduceScatter
*/
void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output);
/*
** Split FPROP GEMM + ReduceScatter
*/
void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output);
}; // CommOverlap
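// PyTorch wrapper around the point-to-point (ring-exchange) comm+GEMM overlap algorithms.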
class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
private:
torch::Tensor _ubuf_torch;
torch::Tensor _ubuf_counter;
public:
CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
CommOverlapHelper *helper, int tp_size,
transformer_engine::CommOverlapType comm_type,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false,
bool use_ce = true, bool aggregate = false);
void set_ubuf_scale_inv(torch::Tensor scale_inv) {
assert(scale_inv.numel());
assert(scale_inv.scalar_type() == torch::kFloat32);
transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv(
reinterpret_cast<float *>(scale_inv.data_ptr()));
}
void copy_input_to_ubuf(torch::Tensor input, bool chunk);
torch::Tensor get_ubuf_output(int comm_type);
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed so that
** the all-gathered outputs on each rank land in contiguous memory after all ring exchange
** phases.
*/
void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy);
/*
** Split AllGather + GEMM using P2P communication
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed so that
** the all-gathered outputs on each rank land in contiguous memory after all ring exchange
** phases.
*/
void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor B_copy);
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor rs_output);
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output);
}; // CommOverlapP2P
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#define HALF_BYTES 2
#define UB_MAX_SM 32
using namespace torch::indexing;
using namespace std::placeholders;
namespace te = transformer_engine;
#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, B, B_scale_inv, \
B_fp8_index, B_type, D, D_amax, D_scale, D_type, bias, \
bias_type, pre_gelu_out, workspace) \
A = A.contiguous(); \
void *A_scale_inv_ptr = nullptr; \
if (te::is_fp8_dtype(A_type)) { \
assert(A_scale_inv.numel()); \
A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \
} \
auto A_ = makeTransformerEngineTensor( \
A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type, \
nullptr, nullptr, A_scale_inv_ptr); \
B = B.contiguous(); \
void *B_scale_inv_ptr = nullptr; \
if (te::is_fp8_dtype(B_type)) { \
assert(B_scale_inv.numel()); \
B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \
} \
auto B_ = makeTransformerEngineTensor( \
B.data_ptr(), {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))}, B_type, \
nullptr, nullptr, B_scale_inv_ptr); \
void *D_amax_ptr = nullptr; \
void *D_scale_ptr = nullptr; \
if (te::is_fp8_dtype(D_type)) { \
assert(D_amax.numel()); \
D_amax_ptr = D_amax.data_ptr(); \
assert(D_scale.numel()); \
D_scale_ptr = D_scale.data_ptr(); \
} \
auto D_ = makeTransformerEngineTensor( \
D.data_ptr(), {static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type, \
D_amax_ptr, D_scale_ptr, nullptr); \
auto bias_ = makeTransformerEngineTensor( \
bias.data_ptr(), std::vector<size_t>{static_cast<size_t>(bias.size(0))}, bias_type); \
const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))} \
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)), \
static_cast<size_t>(pre_gelu_out.size(1))}; \
auto pre_gelu_out_ = makeTransformerEngineTensor( \
pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \
auto workspace_ = makeTransformerEngineTensor( \
workspace.data_ptr(), std::vector<size_t>{static_cast<size_t>(workspace.size(0))}, \
te::DType::kByte);
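// The macro above materializes A_, B_, D_, bias_, pre_gelu_out_ and workspace_ as Transformer
// Engine tensor views over the incoming PyTorch tensors, for use in the wrappers below.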
/***************************************************************************************************
* CommOverlapHelper
**************************************************************************************************/
CommOverlapHelper::CommOverlapHelper() {
#ifndef NVTE_UB_WITH_MPI
NVTE_ERROR("Internal TE error: Dummy CommOverlapHelper init without NVTE_UB_WITH_MPI=1!");
#endif
} // empty constructor for NVTE_UB_WITH_MPI=1
CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group,
std::optional<c10d::ProcessGroup *> intra_domain_group,
std::optional<c10d::ProcessGroup *> inter_domain_group) {
#ifndef NVTE_UB_WITH_MPI
pgs.insert({"world", world_group});
myrank = pgs["world"]->getRank();
numranks = pgs["world"]->getSize();
c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType();
backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL);
if (intra_domain_group.has_value()) {
// Get local rank on node and number of local ranks
NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend,
"Internal TE error: Intra-node group must be on the same backend (%s) as the world ",
"group!", pgs["world"]->getBackendName());
pgs.insert({"intra", intra_domain_group.value()});
mylocal = pgs["intra"]->getRank();
numlocal = pgs["intra"]->getSize();
if (numlocal == numranks) {
// Intra-node group is same as the world group so there can only be 1 node
NVTE_CHECK(
mylocal == myrank,
"Internal TE error: Local rank must be equal to global rank when intra-node group size ",
"is equal to the world group size!");
mynode = 0;
numnodes = 1;
} else {
// Intra-node group is different than the world group so there must be multiple nodes
NVTE_CHECK(
inter_domain_group.has_value(),
"Internal TE error: Inter-node group cannot be `None` when intra-node group is not ",
"identical to the world_group!");
// Get node ID and number of nodes
NVTE_CHECK(
inter_domain_group.value()->getBackendType() == backend,
"Internal TE error: Inter-node group must be on the same backend (%s) as the world ",
"group!", pgs["world"]->getBackendName());
pgs.insert({"inter", inter_domain_group.value()});
mynode = pgs["inter"]->getRank();
numnodes = pgs["inter"]->getSize();
}
} else {
// Intra-node group is not set so we assume there is only 1 node
mylocal = myrank;
numlocal = numranks;
pgs.insert({"intra", world_group});
mynode = 0;
numnodes = 1;
}
initialized = true;
#else
NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!");
#endif
}
CommOverlapHelper::~CommOverlapHelper() {
#ifndef NVTE_UB_WITH_MPI
for (auto &pg : pgs) pg.second = nullptr;
backend_is_nccl = false;
initialized = false;
#endif
}
void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata,
size_t localbytes, ExtComm group) {
#ifndef NVTE_UB_WITH_MPI
NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ",
"with valid process groups!");
auto localtensor =
torch::from_blob(localdata, {static_cast<int64_t>(localbytes / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor;
auto globaltensor =
torch::from_blob(globaldata, {static_cast<int64_t>(globalbytes / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor;
std::vector<std::vector<torch::Tensor>> globalchunks = {globaltmp.chunk(pgs[group]->getSize())};
std::vector<torch::Tensor> localchunk = {localtmp};
auto work = pgs[group]->allgather(globalchunks, localchunk);
work->wait();
if (backend_is_nccl) {
globaltensor.copy_(globaltmp.cpu());
globaltmp = torch::Tensor();
localtmp = torch::Tensor();
}
#else
NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_allgather is a no-op when TE is compiled ",
"with NVTE_UB_WITH_MPI=1!");
#endif
}
void CommOverlapHelper::ub_barrier(ExtComm group) {
#ifndef NVTE_UB_WITH_MPI
NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ",
"with valid process groups!");
auto work = pgs[group]->barrier();
work->wait();
#else
NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ",
"with NVTE_UB_WITH_MPI=1!");
#endif
}
/***************************************************************************************************
* CommOverlap
**************************************************************************************************/
CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
CommOverlapHelper *helper, int tp_size, int num_splits,
int num_max_streams, int comm_cga_size, int num_comm_sm,
bool set_sm_margin, bool atomic_gemm)
: te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank,
helper->numranks, helper->mylocal, helper->numlocal, helper->mynode,
helper->numnodes, tp_size,
std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits,
num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) {
  // Even though we never use these PyTorch tensor wrappers directly, they're still necessary
  // for PyTorch to factor externally allocated memory into its memory pool and garbage
  // collection threshold calculation.
_ubuf_torch = torch::from_blob(
_ubuf.dptr(), {static_cast<int64_t>(_ubuf.size(0)), static_cast<int64_t>(_ubuf.size(1))},
at::device(torch::kCUDA).dtype(buffer_dtype));
if (_atomic_gemm) {
_ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast<int64_t>(_num_splits * 2)},
at::device(torch::kCUDA).dtype(torch::kInt32));
}
}
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
std::vector<at::Tensor> CommOverlap::bulk_overlap(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa,
at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb,
at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias,
te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
te::CommOverlapType comm_type, at::Tensor rs_output) {
MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse,
B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type,
pre_gelu_out, workspace)
auto rs_out_ = makeTransformerEngineTensor(rs_output);
cudaStream_t stream_main = static_cast<cudaStream_t>(at::cuda::getCurrentCUDAStream());
te::CommOverlapBase::bulk_overlap(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_,
grad, accumulate, use_split_accumulator, comm_type, rs_out_,
stream_main);
// Get the current userbuf offset
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
if (comm_type == te::CommOverlapType::RS) {
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
}
// Generate output tensor from userbuf data pointer
int output_c_dim0 =
(comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
auto output_tensor =
torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options());
return {D, output_tensor};
} // CommOverlap::bulk_overlap
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlap::atomic_gemm_overlap_rs(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa,
at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb,
at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias,
te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) {
MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse,
B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type,
pre_gelu_out, workspace)
auto rs_out_ = makeTransformerEngineTensor(rs_output);
cudaStream_t stream_main = static_cast<cudaStream_t>(at::cuda::getCurrentCUDAStream());
te::CommOverlapBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_,
workspace_, grad, accumulate, use_split_accumulator,
gemm_overlap, rs_out_, stream_main);
}  // CommOverlap::atomic_gemm_overlap_rs
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
te::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
te::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale,
te::DType D_type, at::Tensor D_amax, at::Tensor bias,
te::DType bias_type, at::Tensor pre_gelu_out, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) {
MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse,
B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type,
pre_gelu_out, workspace)
auto rs_out_ = makeTransformerEngineTensor(rs_output);
cudaStream_t stream_main = static_cast<cudaStream_t>(at::cuda::getCurrentCUDAStream());
te::CommOverlapBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_,
workspace_, grad, accumulate, use_split_accumulator,
gemm_overlap, rs_out_, stream_main);
} // CommOverlap::split_overlap_rs
/*
** Helper function to copy input to _ubuf
*/
void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) {
char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
te::CommOverlapType _comm_type = static_cast<te::CommOverlapType>(comm_type);
if (_comm_type == te::CommOverlapType::AG) {
if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() ||
input.element_size() != (int64_t)_ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
} else {
if (input.numel() != (int64_t)_ubuf.numel() ||
input.element_size() != (int64_t)_ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
}
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
}
torch::Tensor CommOverlap::get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
te::CommOverlapType _comm_type = static_cast<te::CommOverlapType>(comm_type);
if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS)
NVTE_ERROR("Invalid comm_type");
if (_comm_type == te::CommOverlapType::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
int output_c_dim0 =
(_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1},
torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype())));
}
/***************************************************************************************************
* CommOverlapP2P
**************************************************************************************************/
CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
CommOverlapHelper *helper, int tp_size,
te::CommOverlapType comm_type, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool atomic_gemm, bool use_ce, bool aggregate)
: te::CommOverlapP2PBase(
buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks,
helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size,
std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams,
comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {
  // Even though we never use these PyTorch tensor wrappers directly, they're still necessary
  // for PyTorch to factor externally allocated memory into its memory pool and garbage
  // collection threshold calculation.
_ubuf_torch = torch::from_blob(
_ubuf.dptr(), {static_cast<int64_t>(_ubuf.size(0)), static_cast<int64_t>(_ubuf.size(1))},
at::device(torch::kCUDA).dtype(buffer_dtype));
if (_atomic_gemm) {
_ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast<int64_t>(_num_splits * 2)},
at::device(torch::kCUDA).dtype(torch::kInt32));
}
}
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed so that
** the all-gathered outputs on each rank land in contiguous memory after all ring exchange
** phases.
*/
void CommOverlapP2P::atomic_gemm_overlap_ag(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa,
at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb,
at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias,
te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse,
B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type,
pre_gelu_out, workspace)
auto B_copy_ = makeTransformerEngineTensor(B_copy);
cudaStream_t stream_main = static_cast<cudaStream_t>(at::cuda::getCurrentCUDAStream());
te::CommOverlapP2PBase::atomic_gemm_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_,
workspace_, grad, accumulate,
use_split_accumulator, B_copy_, stream_main);
} // atomic_gemm_overlap_ag
/*
** Split AllGather + GEMM using P2P communication
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed so that
** the all-gathered outputs on each rank land in contiguous memory after all ring exchange
** phases.
*/
void CommOverlapP2P::split_overlap_ag(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa,
at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb,
at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias,
te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse,
B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type,
pre_gelu_out, workspace)
auto B_copy_ = makeTransformerEngineTensor(B_copy);
cudaStream_t stream_main = static_cast<cudaStream_t>(at::cuda::getCurrentCUDAStream());
te::CommOverlapP2PBase::split_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_,
workspace_, grad, accumulate, use_split_accumulator,
B_copy_, stream_main);
} // split_overlap_ag
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2P::atomic_gemm_overlap_rs(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa,
at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb,
at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias,
te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) {
MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse,
B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type,
pre_gelu_out, workspace)
auto rs_out_ = makeTransformerEngineTensor(rs_output);
cudaStream_t stream_main = static_cast<cudaStream_t>(at::cuda::getCurrentCUDAStream());
te::CommOverlapP2PBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_,
workspace_, grad, accumulate,
use_split_accumulator, rs_out_, stream_main);
}
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2P::split_overlap_rs(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa,
at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb,
at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias,
te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) {
MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse,
B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type,
pre_gelu_out, workspace)
auto rs_out_ = makeTransformerEngineTensor(rs_output);
cudaStream_t stream_main = static_cast<cudaStream_t>(at::cuda::getCurrentCUDAStream());
te::CommOverlapP2PBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_,
workspace_, grad, accumulate, use_split_accumulator,
rs_out_, stream_main);
}
/*
** Copy input to _ubufs[0]
*/
void CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) {
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
if (chunk) {
// Copy input to the target ubuf chunk by rank offset
if (input.numel() != (int64_t)_ubufs[0].numel() ||
input.element_size() != (int64_t)_ubufs[0].element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
} else {
if (input.numel() != (int64_t)_ubuf.numel() ||
input.element_size() != (int64_t)_ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
}
}
torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
te::CommOverlapType _comm_type = static_cast<te::CommOverlapType>(comm_type);
if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS)
NVTE_ERROR("Invalid comm_type");
if (_comm_type == te::CommOverlapType::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size();
int output_c_dim0 =
(_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options());
}
@@ -4,12 +4,15 @@
 * See LICENSE for license information.
 ************************************************************************/
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "../comm_gemm_overlap.h"
#include "../extensions.h"
#include "common/util/pybind_helper.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m)
  // Permutation functions
  m.def("moe_permute_fwd", moe_permute_fwd);
  m.def("moe_permute_bwd", moe_permute_bwd);
@@ -226,90 +229,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
      .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
m.def("device_supports_multicast", &ubuf::device_supports_multicast,
py::call_guard<py::gil_scoped_release>());
m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi,
py::call_guard<py::gil_scoped_release>());
py::class_<ubuf::UbufBootstrapCallbacks>(m, "UbufBootstrapCallbacks")
.def(py::init<>(), py::call_guard<py::gil_scoped_release>())
.def(py::init<c10d::ProcessGroup *, c10d::ProcessGroup *>(),
py::call_guard<py::gil_scoped_release>());
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P)
.value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P)
.value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS)
.value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P)
.value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P);
// Note: Can't release GIL in constructor since it may bootstrap
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, int, bool, int,
bool, ubuf::UbufBootstrapCallbacks &>(),
py::call_guard<py::gil_scoped_release>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>());
// Note: Can't release GIL in constructor since it may bootstrap
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, bool, bool, int,
bool, bool, bool, ubuf::UbufBootstrapCallbacks &>(),
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>());
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
.value("kFloat32", transformer_engine::DType::kFloat32)
.value("kFloat16", transformer_engine::DType::kFloat16)
.value("kBFloat16", transformer_engine::DType::kBFloat16)
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2);
  py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors")
      .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
      .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
@@ -329,41 +248,61 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3)
      .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3);
  py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
      .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
      .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
      .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)
      .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI);
  py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
      .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
      .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
      .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
      .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
      .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
      .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
             NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
  py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
      .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
      .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
      .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
      .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D)
      .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD)
      .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
      .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D)
      .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
      .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D)
      .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
      .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
      .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)
      .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
      .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)
      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
  py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend")
      .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
      .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
      .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
      .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
  py::class_<CommOverlapHelper>(m, "CommOverlapHelper")
      .def(py::init<>(), py::call_guard<py::gil_scoped_release>())
      .def(py::init<c10d::ProcessGroup *, std::optional<c10d::ProcessGroup *>,
                    std::optional<c10d::ProcessGroup *>>(),
           py::call_guard<py::gil_scoped_release>(), py::arg("world_group"),
           py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none());
  py::class_<CommOverlap>(m, "CommOverlap")
      .def(py::init<const std::vector<size_t> &, at::ScalarType, CommOverlapHelper *, int, int, int,
                    int, int, bool, bool>(),
           py::call_guard<py::gil_scoped_release>(), py::arg("buffer_shape"),
           py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"),
           py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS,
           py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16,
           py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false)
      .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard<py::gil_scoped_release>())
      .def("split_overlap_rs", &CommOverlap::split_overlap_rs,
           py::call_guard<py::gil_scoped_release>())
      .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs,
           py::call_guard<py::gil_scoped_release>())
      .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf,
           py::call_guard<py::gil_scoped_release>())
      .def("get_ubuf_output", &CommOverlap::get_ubuf_output,
           py::call_guard<py::gil_scoped_release>())
      .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv,
           py::call_guard<py::gil_scoped_release>())
      .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard<py::gil_scoped_release>())
      .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard<py::gil_scoped_release>())
      .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard<py::gil_scoped_release>());
  py::class_<CommOverlapP2P>(m, "CommOverlapP2P")
      .def(py::init<const std::vector<size_t> &, at::ScalarType, CommOverlapHelper *, int,
                    transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(),
           py::call_guard<py::gil_scoped_release>(), py::arg("buffer_shape"),
           py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"),
           py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1,
           py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false,
           py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false)
      .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag,
           py::call_guard<py::gil_scoped_release>())
      .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs,
           py::call_guard<py::gil_scoped_release>())
      .def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag,
           py::call_guard<py::gil_scoped_release>())
      .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs,
           py::call_guard<py::gil_scoped_release>())
      .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf,
           py::call_guard<py::gil_scoped_release>())
      .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output,
           py::call_guard<py::gil_scoped_release>())
      .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv,
           py::call_guard<py::gil_scoped_release>())
      .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard<py::gil_scoped_release>())
      .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm,
           py::call_guard<py::gil_scoped_release>())
      .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap,
           py::call_guard<py::gil_scoped_release>());
}
@@ -87,9 +87,55 @@ def initialize_ub(
    ub_cfgs: Optional[dict] = None,
    bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
    """Initialize communicators for TP comm overlap using userbuffers."""
    r"""
Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.
Parameters
----------
shape : list
shape of the communication buffer, typically set to be the same as the global shape of
the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
tp_size : int
number of GPUs in the tensor-parallel process group
use_fp8 : bool = False
allocate the communication buffer for FP8 GEMM inputs/outputs
dtype : torch.dtype = torch.bfloat16
non-FP8 data type of the communication buffer when `use_fp8 = False`
ub_cfgs: dict = None
Configuration dictionary with the structure
```
{
<gemm_name> : {
"method": <"ring_exchange" or "pipeline">,
"is_reduce_scatter": bool,
"num_sm": int,
"cga_size": int,
"set_sm_margin": bool,
"num_splits": int,
"aggregate": bool,
"atomic_gemm": bool,
"use_ce": bool,
"fp8_buf": bool,
}
}
```
        for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
        "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc1_wgrad",
        "fc2_fprop", "fc2_dgrad"]`.
bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are
valid for every cluster configuration and distributed launch method even if
they are available in PyTorch. When left unset, the initialization prefers
to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
option and always initializes Userbuffers with direct MPI calls in C++,
which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.
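
    Example
    -------
    A minimal sketch of a typical call (the sizes and the `nccl` backend choice below are
    illustrative placeholders, not requirements):
    ```
    import torch
    import transformer_engine.pytorch as te

    # one process per GPU, e.g. launched with torchrun
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

    seq_length, batch_size, hidden_size = 2048, 2, 4096  # illustrative sizes
    te.initialize_ub(
        [seq_length * batch_size, hidden_size],
        torch.distributed.get_world_size(),  # tensor-parallel group spans all GPUs here
        dtype=torch.bfloat16,
        bootstrap_backend="nccl",
    )
    ```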
"""
    if not tex.device_supports_multicast():
        assert bool(os.getenv("UB_SKIPMC", "0")), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )
        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )
@@ -99,50 +145,52 @@ def initialize_ub(
    _ub_communicators = {}
    if tex.ubuf_built_with_mpi():
        # Userbuffers will ignore all these values when it is built with MPI, so these are just
        # placeholders based on an assumption that tp_size covers all devices in a physical node.
        assert torch.distributed.is_mpi_available()
        mpi_group = torch.distributed.new_group(backend="mpi")
        world_rank = torch.distributed.get_rank(mpi_group)
        world_size = torch.distributed.get_world_size(mpi_group)
        local_rank = world_rank % tp_size
        local_size = tp_size
        self_node_idx = world_rank // tp_size
        num_nodes = world_size // tp_size
        ub_callbacks = tex.UbufBootstrapCallbacks()
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            bootstrap_backend = "nccl"
            if torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
            elif torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in ["gloo", "mpi", "nccl"]
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )
        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)
        # Construct an intra-node communicator based on global ranks that share the same hostname
        # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host
        # address on that interface instead of the hostname. This can help avoid issues when
        # different hosts have the same hostname on Kubernetes clusters.
        hostname = socket.gethostname()
        ifname = os.getenv(
            "NVTE_UB_SOCKET_IFNAME",
            os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
        )
        # We have single-node NVLink so we can color based on physical node hostnames.
        # NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and
        # otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on
        # the chosen bootstrap backend.
        mydomain = socket.gethostname()
        ifname = os.getenv(
            "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME")
        )
if ifname is not None: if ifname is not None:
# Make sure the ifname found in the environment is a valid network interface # Make sure the ifname found in the environment is a valid network interface
if ifname in [name for _, name in socket.if_nameindex()]: if ifname in [name for _, name in socket.if_nameindex()]:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try: try:
hostname = socket.inet_ntoa( mydomain = socket.inet_ntoa(
fcntl.ioctl( fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24] )[20:24]
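For readers unfamiliar with the `0x8915` magic number in the hunk above: it is the Linux `SIOCGIFADDR` ioctl, which fills in a `struct ifreq` for the named interface; bytes 20:24 of the result hold the packed IPv4 address. A self-contained sketch of the same trick (Linux-only; `interface_ipv4` is an illustrative helper, not a TE function):

```python
import fcntl
import socket
import struct

SIOCGIFADDR = 0x8915  # Linux ioctl request: get interface address

def interface_ipv4(ifname: str) -> str:
    """Return the IPv4 address bound to `ifname` (e.g. "eth0")."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # The kernel copies a struct ifreq back; the sockaddr_in payload
        # stores the 4-byte address at offset 20:24.
        packed = fcntl.ioctl(
            s.fileno(), SIOCGIFADDR, struct.pack("256s", ifname[:15].encode("UTF-8"))
        )
        return socket.inet_ntoa(packed[20:24])
    finally:
        s.close()
```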
@@ -154,57 +202,63 @@ def initialize_ub(
             else:
                 ifname_warning = (
                     f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will"
-                    " attempt to "
-                    + "detect ranks on the same node by matching 'socket.gethostname()', which is "
-                    + "known to fail on virtual clusters like Kubernetes. If Userbuffers "
-                    + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in "
-                    + "your environment to the correct network interface."
+                    + " attempt to detect ranks on the same node by matching "
+                    + "'socket.gethostname()', which is known to fail on virtual clusters like "
+                    + "Kubernetes. If Userbuffers initialization fails, please set the "
+                    + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network "
+                    + "interface."
                 )
                 warnings.warn(ifname_warning, UserWarning)

-        hostnames = [None for _ in range(world_size)]
-        torch.distributed.all_gather_object(hostnames, hostname, world_group)
-        unique_hosts = []
-        for host in hostnames:
-            if host not in unique_hosts:
-                unique_hosts.append(host)
-        num_nodes = len(unique_hosts)
+        # Allgather the domain colors across ranks and reduce to a list of unique domains
+        domain_per_rank_list = [None for _ in range(world_size)]
+        torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group)
+        unique_domains = []
+        for domain in domain_per_rank_list:
+            if domain not in unique_domains:
+                unique_domains.append(domain)
+        num_domains = len(unique_domains)

-        if num_nodes > 1:
-            ranks_per_node_list = [[] for _ in range(num_nodes)]
-            self_node_idx = -1
-            for i, host in enumerate(hostnames):
-                node_idx = unique_hosts.index(host)
-                ranks_per_node_list[node_idx].append(i)
-                if host == hostname:
-                    self_node_idx = node_idx
-            assert self_node_idx >= 0, "Internal TE error!"
-
-            intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration(
-                ranks_per_node_list, backend=bootstrap_backend
-            )
-            local_rank = torch.distributed.get_rank(intra_node_group)
-            local_size = torch.distributed.get_world_size(intra_node_group)
-            intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group)
+        if num_domains > 1:
+            # DP/TP model replicated on multiple NVLink domains
+            ranks_per_domain_list = [[] for _ in range(num_domains)]
+            mydomain_idx = -1
+            for i, domain in enumerate(domain_per_rank_list):
+                domain_idx = unique_domains.index(domain)
+                ranks_per_domain_list[domain_idx].append(i)
+                if domain == mydomain:
+                    mydomain_idx = domain_idx
+            assert mydomain_idx >= 0, "Internal TE error!"
+
+            intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
+                ranks_per_domain_list, backend=bootstrap_backend
+            )
+            local_rank = torch.distributed.get_rank(intra_domain_group)
+            intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group)
+
+            inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
+                [list(ranks) for ranks in zip(*ranks_per_domain_list)],
+                backend=bootstrap_backend,
+            )
+
+            helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group)
         else:
-            self_node_idx = 0
-            intra_node_group = world_group
+            # TP model on single NVLink domain, no replication, no data-parallelism
+            mydomain_idx = 0
             local_rank = world_rank
-            local_size = world_size
-            intra_node_ranks = list(range(world_size))
+            intra_domain_ranks = list(range(world_size))
+
+            helper = tex.CommOverlapHelper(world_group)

         if world_rank == 0:
-            print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True)
+            print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True)
         if local_rank == 0:
             print(
-                f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n",
+                f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n",
                 end="",
                 flush=True,
             )
-        ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group)

     # Increase the workspace by the number of maximum concurrent streams
     global _cublas_workspace
     _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
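The bootstrapping logic in this hunk colors each rank by its NVLink domain (the hostname, or the interface address when a socket interface is specified), then builds one intra-domain subgroup per color plus "transposed" inter-domain subgroups. A condensed, standalone sketch of that grouping, assuming `torch.distributed` is already initialized and every domain holds the same number of ranks (`make_domain_groups` is a hypothetical name, not TE API):

```python
import socket
import torch.distributed as dist

def make_domain_groups(world_group, backend):
    """Group ranks by NVLink-domain color (sketch; assumes equal-sized domains)."""
    world_size = dist.get_world_size(world_group)
    mydomain = socket.gethostname()  # this rank's color

    # Gather every rank's color and reduce to an ordered list of unique domains
    colors = [None] * world_size
    dist.all_gather_object(colors, mydomain, world_group)
    unique = list(dict.fromkeys(colors))  # preserves first-seen order

    # Bucket global ranks by their domain index
    ranks_per_domain = [[] for _ in unique]
    for rank, color in enumerate(colors):
        ranks_per_domain[unique.index(color)].append(rank)

    # One subgroup per domain; each rank lands in the subgroup that covers it
    intra, _ = dist.new_subgroups_by_enumeration(ranks_per_domain, backend=backend)

    # "Transposed" subgroups connecting the i-th rank of every domain
    inter, _ = dist.new_subgroups_by_enumeration(
        [list(r) for r in zip(*ranks_per_domain)], backend=backend
    )
    return intra, inter
```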
@@ -303,46 +357,34 @@ def initialize_ub(
         if atomic_gemm and method == "ring_exchange":
             assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

-        sample_buffer = torch.empty(
-            shape, dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype, device="cuda"
-        )
+        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
         if method == "ring_exchange":
-            ub_obj = tex.UbufP2PCommOverlap(
-                sample_buffer,  # Sample userbuffer
-                world_rank,  # World rank
-                world_size,  # World size
-                local_rank,  # Rank within the node
-                local_size,  # Number of ranks/GPUs per node
-                self_node_idx,  # Node ID
-                num_nodes,  # Number of nodes
-                tp_size,  # Tensor-parallel group size (may be different than local_size)
-                num_sm,  # Number of communication SMs
-                cga_size,  # CGA cluster size
-                set_sm_margin,  # Set SM margin
-                aggregate,  # Aggregate 2X GEMM chunks
-                _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
-                is_reduce_scatter,  # Overlap with reduce scatter
-                atomic_gemm,  # Use a single GEMM with atomic-counters
-                use_ce,  # Use copy engine for P2P communications
-                ub_callbacks,
-            )
+            ub_obj = tex.CommOverlapP2P(
+                shape,  # Communication buffer shape
+                buffer_dtype,  # Communication buffer data type
+                helper,  # Helper for torch.distributed callbacks during bootstrapping
+                tp_size,  # Tensor-parallel group size (may be different than local_size)
+                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
+                num_max_streams=_NUM_MAX_UB_STREAMS,
+                comm_cga_size=cga_size,
+                num_comm_sm=num_sm,
+                set_sm_margin=set_sm_margin,
+                atomic_gemm=atomic_gemm,
+                use_ce=use_ce,
+                aggregate=aggregate,
+            )
         else:
-            ub_obj = tex.UbufCommOverlap(
-                sample_buffer,  # Sample userbuffer
-                world_rank,  # World rank
-                world_size,  # World size
-                local_rank,  # Rank within the node
-                local_size,  # Number of ranks/GPUs per node
-                self_node_idx,  # Node ID
-                num_nodes,  # Number of nodes
-                tp_size,  # Tensor-parallel group size (may be different than local_size)
-                num_sm,  # Number of communication SMs
-                cga_size,  # CGA cluster size
-                num_splits,  # Number of communication splits
-                set_sm_margin,  # Set SM margin
-                _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
-                atomic_gemm,  # Use a single GEMM with atomic-counters
-                ub_callbacks,
-            )
+            ub_obj = tex.CommOverlap(
+                shape,  # Communication buffer shape
+                buffer_dtype,  # Communication buffer data type
+                helper,  # Helper for torch.distributed callbacks during bootstrapping
+                tp_size,  # Tensor-parallel group size (may be different than local_size)
+                num_splits=num_splits,
+                num_max_streams=_NUM_MAX_UB_STREAMS,
+                comm_cga_size=cga_size,
+                num_comm_sm=num_sm,
+                set_sm_margin=set_sm_margin,
+                atomic_gemm=atomic_gemm,
+            )
         _ub_communicators[name] = ub_obj
...
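For orientation, a typical call into the refactored entry point looks roughly like the sketch below. The keyword names are assumptions to be checked against the `te.initialize_ub` docstring added in this PR, and the buffer shape here is hypothetical; the communicators built above are later fetched by name via `get_ub(...)` inside the modules, as the hunks that follow show.

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Hypothetical sizes: buffer shape is (sequence length * micro batch, hidden size)
te.initialize_ub(
    [2048 * 2, 12288],        # communication buffer shape
    8,                        # tensor-parallel group size
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="mpi",  # omit to use the MPI > Gloo > NCCL default priority
)
```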
@@ -161,9 +161,9 @@ class _LayerNormLinear(torch.autograd.Function):
             if not return_layernorm_output:
                 ln_out = torch.empty_like(ln_out)
             if ub_obj_lnout.is_atomic_gemm():
-                ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+                ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
             else:
-                ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
         elif parallel_mode == "column" and sequence_parallel:
             ln_out_gathered = True
             ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
@@ -293,7 +293,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 get_workspace(),
                 bias=bias,
                 use_bias=use_bias,
-                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
+                ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
                 ub=ub_obj_lnout if ub_overlap_ag else None,
                 extra_output_tensor=ln_out if ub_overlap_ag else None,
             )
@@ -485,7 +485,7 @@ class _LayerNormLinear(torch.autograd.Function):
         rs_out = None
         if ctx.ub_bulk_dgrad:
-            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+            ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG
             ub_obj = ub_obj_lnout
         elif ctx.ub_overlap_rs_dgrad:
             dim_size = list(grad_output.size())
@@ -496,14 +496,14 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             if ub_obj_dgrad.is_p2p_overlap():
                 if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
-                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                 else:
-                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
             else:
                 if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
-                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                 else:
-                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
             ub_obj = ub_obj_dgrad
         else:
             ub_algo = None
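The `UbufOverlapAlgo` to `CommOverlapAlgo` renames in these modules all follow the same decision tree: P2P-capable communicators get the `_P2P` variants, and atomic-GEMM communicators (under FP8) get the `ATOMIC_GEMM_*` variants. A hedged helper capturing the reduce-scatter case, for readers tracing the pattern (`pick_rs_overlap_algo` is illustrative, not a TE function, and the `tex` import path is an assumption):

```python
import transformer_engine.pytorch.cpp_extensions as tex

def pick_rs_overlap_algo(ub_obj, fp8: bool):
    """Map a communicator's capabilities to the reduce-scatter overlap algo."""
    if ub_obj.is_p2p_overlap():
        # Ring-exchange (P2P) communicators use the _P2P algorithm variants
        if fp8 and ub_obj.is_atomic_gemm():
            return tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
        return tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
    # Collective (pipelined) communicators use the non-P2P variants
    if fp8 and ub_obj.is_atomic_gemm():
        return tex.CommOverlapAlgo.ATOMIC_GEMM_RS
    return tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
```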
@@ -616,7 +616,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                     use_split_accumulator=_2X_ACC_WGRAD,
                     ub_algo=(
-                        tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
+                        tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
                     ),
                     ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
                     extra_output_tensor=extra_output_tensor,
@@ -640,7 +640,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                     ub_algo=(
-                        tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
+                        tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
                     ),
                     ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
                     extra_output_tensor=extra_output_tensor,
@@ -658,7 +658,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     use_bias=ctx.use_bias,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
-                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
+                    ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
                     ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
                 )
             clear_tensor_data(ln_out_total)
...
@@ -180,9 +180,9 @@ class _LayerNormMLP(torch.autograd.Function):
             ln_out_total = ub_obj_lnout.get_ubuf_output(1)
             ln_out = torch.empty_like(ln_out)
             if ub_obj_lnout.is_atomic_gemm():
-                ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+                ub_algo_ag = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
             else:
-                ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                ub_algo_ag = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
         elif set_parallel_mode and sequence_parallel:
             ln_out_gathered = True
             ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
@@ -298,14 +298,14 @@ class _LayerNormMLP(torch.autograd.Function):
             rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
             if ub_obj_fc2out.is_p2p_overlap():
                 if ub_obj_fc2out.is_atomic_gemm():
-                    ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                 else:
-                    ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                    ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
             else:
                 if ub_obj_fc2out.is_atomic_gemm():
-                    ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                    ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                 else:
-                    ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                    ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS

             if ub_obj_fc2out.is_fp8_ubuf():
                 fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT
@@ -369,7 +369,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc1_bias,
                 use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
                 gelu=not bias_gelu_nvfusion and (activation == "gelu"),
-                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
+                ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
                 ub=ub_obj_lnout if ub_overlap_ag else None,
                 extra_output_tensor=ln_out if ub_overlap_ag else None,
             )
@@ -410,9 +410,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 dim_size[1] = fc2_weight.size(0)
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
                 if ub_obj_fc2out.is_p2p_overlap():
-                    ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                    ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                 else:
-                    ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                    ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
             else:
                 dim_size = list(gelu_out.size())
                 dim_size[1] = fc2_weight.size(0)
@@ -615,9 +615,9 @@ class _LayerNormMLP(torch.autograd.Function):
             dim_size[0] = dim_size[0] * tp_world_size
             ctx.ub_obj_gradout = get_ub("fc2_dgrad")
             if ctx.ub_obj_gradout.is_atomic_gemm():
-                ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+                ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
             else:
-                ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P

         ctx.use_bias = ctx.use_fc2_bias  # For grad_output_preprocess
         (
@@ -788,7 +788,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap
         rs_out = None
         if ctx.ub_bulk_dgrad:
-            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+            ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG
             ub_obj = ub_obj_lnout
         elif ctx.ub_overlap_rs_dgrad:
             dim_size = list(dgelu.size())
@@ -797,14 +797,14 @@ class _LayerNormMLP(torch.autograd.Function):
             rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device)
             if ub_obj_dgrad.is_p2p_overlap():
                 if ub_obj_dgrad.is_atomic_gemm():
-                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                 else:
-                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
             else:
                 if ub_obj_dgrad.is_atomic_gemm():
-                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                 else:
-                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
             ub_obj = ub_obj_dgrad
         else:
             ub_algo = None
@@ -842,7 +842,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 grad=True,
                 gelu_input=fc1_out,
                 ub_algo=(
-                    tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None
+                    tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None
                 ),
                 ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
             )
@@ -892,7 +892,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap
         if ctx.ub_bulk_dgrad:
-            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+            ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG
             ub_obj = ub_obj_lnout
         elif ctx.ub_overlap_rs_dgrad:
             dim_size = list(dgelu.size())
@@ -900,9 +900,9 @@ class _LayerNormMLP(torch.autograd.Function):
             dim_size[1] = fc1_weight.size(1)
             rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device)
             if ub_obj_dgrad.is_p2p_overlap():
-                ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
             else:
-                ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
             ub_obj = ub_obj_dgrad
         else:
             ub_algo = None
@@ -967,7 +967,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                     use_split_accumulator=_2X_ACC_WGRAD,
                     ub_algo=(
-                        tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
+                        tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
                     ),
                     ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
                     extra_output_tensor=extra_output_tensor,
@@ -991,7 +991,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                     ub_algo=(
-                        tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
+                        tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
                     ),
                     ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
                     extra_output_tensor=extra_output_tensor,
@@ -1009,7 +1009,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     use_bias=not ctx.bias_gelu_nvfusion,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
-                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
+                    ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
                     ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
                 )
             clear_tensor_data(ln_out_total, dgelu)
...
@@ -190,14 +190,14 @@ class _Linear(torch.autograd.Function):
             rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
             if ub_obj_projout.is_p2p_overlap():
                 if ub_obj_projout.is_atomic_gemm():
-                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                 else:
-                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
             else:
                 if ub_obj_projout.is_atomic_gemm():
-                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                 else:
-                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS

             if ub_obj_projout.is_fp8_ubuf():
                 proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                 meta_tensor = fp8_meta["scaling_fwd"]
@@ -269,9 +269,9 @@ class _Linear(torch.autograd.Function):
                 dim_size[1] = out_features
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                 if ub_obj_projout.is_p2p_overlap():
-                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                 else:
-                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
             else:
                 dim_size = list(inputmat_total.size())
                 dim_size[1] = out_features
@@ -407,9 +407,9 @@ class _Linear(torch.autograd.Function):
             dim_size[0] = dim_size[0] * tp_world_size
             ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
             if ctx.ub_obj_gradout.is_atomic_gemm():
-                ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+                ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
             else:
-                ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P

         (
             grad_output,
@@ -496,7 +496,7 @@ class _Linear(torch.autograd.Function):
                 layout="NN",
                 grad=True,
                 ub_algo=(
-                    tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                    tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
                     if ctx.ub_overlap_ag
                     else None
                 ),
...