Unverified Commit 933294dc authored by Alp Dener, committed by GitHub

[C/PyTorch] Userbuffers and comm+GEMM overlap algorithms refactored and moved to TE/common (#1067)



* moved userbuffers code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved comm+GEMM overlap code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed PyTorch dependency from comm+GEMM overlap in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated TE/PyTorch Python API to match the refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated unit tests to work with refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* added a pylint exception to comm+GEMM overlap test runner
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixing linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* added documentation for te.initialize_ub
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed compile errors when building with NVTE_UB_WITH_MPI=1
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed default bootstrap backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* switched default bootstrap backend priority to MPI > Gloo > NCCL
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated bootstrap backend documentation
Signed-off-by: Alp Dener <adener@nvidia.com>

* close UB bootstrap socket to avoid interfering with CUDA Multicast shareable file handle send/recv
Signed-off-by: Alp Dener <adener@nvidia.com>

* added torch::Tensor wrappers for communication buffer and atomic counters so PyTorch can factor externally allocated memory into its garbage collection threshold
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* automated handling of world, local and node ranks/sizes within C++ CommOverlapHelper to simplify Python function signatures
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed incorrect read of environment variables
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected priority for _SOCKET_IFNAME environment variables in UB bootstrapping
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved multicast support check to cuda_runtime.h and replaced cudaDeviceGetProp call with cached sm_count()
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* removed commented out old code and replaced external collective function type defines with aliases
Signed-off-by: Alp Dener <adener@nvidia.com>

* compile-time CUDA version guard for CUDA Driver Multicast attribute
Signed-off-by: Alp Dener <adener@nvidia.com>

* added compile-time CUDA version guards to Multicast code in Userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* condensed UB docs, corrected const violations
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed autodoc rst for UB calls, added CUDA version guard on Multicast UB kernels
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect UB type reporting for P2P overlaps, comment reformatting
Signed-off-by: Alp Dener <adener@nvidia.com>

* add docstring to tex.ubuf_built_with_mpi()
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 35bbe740
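
For orientation before the diffs: the refactor renames the old Userbuffers bindings to the new TE/common wrappers. The sketch below only illustrates the name mapping taken from the diffs that follow; the import path for the compiled extension is an assumption, not something this commit defines.

```python
# Hedged name-mapping sketch (names taken from the diffs below; import path assumed).
import transformer_engine.pytorch.cpp_extensions as tex  # assumed import path for the extension

# tex.UbufOverlapAlgo        -> tex.CommOverlapAlgo    (overlap algorithm enum)
# tex.UbufCommOverlap        -> tex.CommOverlap        (pipelined collective overlap)
# tex.UbufP2PCommOverlap     -> tex.CommOverlapP2P     (ring-exchange P2P overlap)
# tex.UbufBootstrapCallbacks -> tex.CommOverlapHelper  (torch.distributed bootstrap helper)
algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P  # enum members keep their previous names
```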
......@@ -45,8 +45,8 @@ def fp8_gemm(
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
ub_algo: tex.UbufOverlapAlgo = None,
ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
ub_algo: tex.CommOverlapAlgo = None,
ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None,
extra_output_tensor: torch.Tensor = None,
) -> torch.Tensor:
"""TN layout GEMM with fp8 inputs."""
......@@ -107,7 +107,7 @@ def fp8_gemm(
fn = torch.ops.tex_ts.te_gemm_ts
if ub_algo is not None:
assert ub is not None, "ub object is None!"
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
......@@ -115,11 +115,11 @@ def fp8_gemm(
args = tuple(
args
+ (
1,
tex.CommOverlapType.AG,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
......@@ -127,23 +127,23 @@ def fp8_gemm(
args = tuple(
args
+ (
0,
tex.CommOverlapType.RS,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
fn = ub.split_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P:
fn = ub.atomic_gemm_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
......@@ -155,13 +155,13 @@ def fp8_gemm(
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
fn = ub.split_overlap_rs_p2p
assert (
extra_output_tensor is not None
), "SPLIT_PIPELINED_RS_P2P requires extra output tensor"
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS:
fn = ub.atomic_gemm_overlap_rs
assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor"
args = tuple(
......@@ -171,15 +171,12 @@ def fp8_gemm(
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P:
elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P:
fn = ub.atomic_gemm_overlap_rs_p2p
assert (
extra_output_tensor is not None
), "ATOMIC_GEMM_RS_P2P requires extra output tensor"
args = tuple(args + (extra_output_tensor,))
if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
out = fn(*args)
else:
_ = fn(*args)
return out, gelu_input
......@@ -198,8 +195,8 @@ def gemm(
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
ub_algo: tex.UbufOverlapAlgo = None,
ub: tex.UbufCommOverlap = None,
ub_algo: tex.CommOverlapAlgo = None,
ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None,
extra_output_tensor: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Non FP8 GEMM."""
......@@ -270,19 +267,19 @@ def gemm(
fn = torch.ops.tex_ts.te_gemm_ts
if ub_algo is not None:
assert ub is not None, "ub object is None!"
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
args = tuple(args + (tex.CommOverlapType.AG, empty_tensor))
elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
args = tuple(args + (tex.CommOverlapType.RS, empty_tensor))
elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
fn = ub.split_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
......@@ -294,7 +291,7 @@ def gemm(
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
fn = ub.split_overlap_rs_p2p
assert (
extra_output_tensor is not None
......
This diff is collapsed.
......@@ -24,6 +24,7 @@
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
......@@ -37,12 +38,14 @@
#include <transformer_engine/transpose.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <cassert>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <vector>
#include "common/util/logging.h"
......
......@@ -7,6 +7,8 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#include <optional>
#include "common.h"
#include "common/common.h"
......@@ -504,4 +506,184 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list);
/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
class CommOverlapHelper : torch::CustomClassHolder {
private:
bool initialized{false};
bool backend_is_nccl{false};
std::map<std::string, c10d::ProcessGroup *> pgs;
public:
int myrank = -1;
int numranks = -1;
int mylocal = -1;
int numlocal = -1;
int mynode = -1;
int numnodes = -1;
CommOverlapHelper();
CommOverlapHelper(c10d::ProcessGroup *world_group,
std::optional<c10d::ProcessGroup *> intra_node_group,
std::optional<c10d::ProcessGroup *> inter_node_group);
~CommOverlapHelper();
void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
ExtComm comm);
void ub_barrier(ExtComm comm);
};
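
From the Python side, CommOverlapHelper is exposed through the pybind11 bindings further down in this diff. A minimal construction sketch, assuming torch.distributed is already initialized; the Gloo backend and the import path are illustrative assumptions.

```python
# Minimal sketch, assuming an initialized torch.distributed job.
import torch.distributed as dist
import transformer_engine.pytorch.cpp_extensions as tex  # assumed import path

world_group = dist.new_group(backend="gloo")  # bootstrap collectives run over this group
helper = tex.CommOverlapHelper(world_group)   # single NVLink domain: intra/inter-node groups omitted
# Multi-domain case (hypothetical groups):
# helper = tex.CommOverlapHelper(world_group, intra_node_group, inter_node_group)
```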
class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
private:
torch::Tensor _ubuf_torch;
torch::Tensor _ubuf_counter;
public:
CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
CommOverlapHelper *helper, int tp_size, int num_splits = 3,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);
void set_ubuf_scale_inv(torch::Tensor scale_inv) {
assert(scale_inv.numel());
assert(scale_inv.scalar_type() == torch::kFloat32);
transformer_engine::CommOverlapBase::set_ubuf_scale_inv(
reinterpret_cast<float *>(scale_inv.data_ptr()));
}
void copy_input_to_ubuf(torch::Tensor input, int comm_type);
torch::Tensor get_ubuf_output(int comm_type);
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
std::vector<at::Tensor> bulk_overlap(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
transformer_engine::CommOverlapType comm_type, at::Tensor rs_output);
/*
** Atomic FPROP GEMM + ReduceScatter
*/
void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output);
/*
** Split FPROP GEMM + ReduceScatter
*/
void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output);
}; // CommOverlap
class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
private:
torch::Tensor _ubuf_torch;
torch::Tensor _ubuf_counter;
public:
CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
CommOverlapHelper *helper, int tp_size,
transformer_engine::CommOverlapType comm_type,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false,
bool use_ce = true, bool aggregate = false);
void set_ubuf_scale_inv(torch::Tensor scale_inv) {
assert(scale_inv.numel());
assert(scale_inv.scalar_type() == torch::kFloat32);
transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv(
reinterpret_cast<float *>(scale_inv.data_ptr()));
}
void copy_input_to_ubuf(torch::Tensor input, bool chunk);
torch::Tensor get_ubuf_output(int comm_type);
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed
** so that the all-gather outputs on each rank end up in contiguous memory after all
** ring-exchange phases.
*/
void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy);
/*
** Split AllGather + GEMM using P2P communication
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed
** so that the all-gather outputs on each rank end up in contiguous memory after all
** ring-exchange phases.
*/
void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor B_copy);
/*
** Split ReduceScatter + AtomicGEMM using P2P communication
*/
void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor rs_output);
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output);
}; // CommOverlapP2P
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
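
Both wrapper classes above validate the scale-inverse tensor before forwarding it to the TE/common base class. A hedged sketch of what a caller would pass, where `ub_obj` stands in for an already-constructed CommOverlap or CommOverlapP2P instance (e.g. created by te.initialize_ub()):

```python
# Hedged sketch: the header asserts a non-empty float32 tensor, so a one-element CUDA
# tensor satisfies it. `ub_obj` is assumed to exist and is not constructed here.
import torch

scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")
ub_obj.set_ubuf_scale_inv(scale_inv)
```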
......@@ -4,12 +4,15 @@
* See LICENSE for license information.
************************************************************************/
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "../comm_gemm_overlap.h"
#include "../extensions.h"
#include "common/util/pybind_helper.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m)
// Permutation functions
m.def("moe_permute_fwd", moe_permute_fwd);
m.def("moe_permute_bwd", moe_permute_bwd);
......@@ -226,90 +229,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
m.def("device_supports_multicast", &ubuf::device_supports_multicast,
py::call_guard<py::gil_scoped_release>());
m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi,
py::call_guard<py::gil_scoped_release>());
py::class_<ubuf::UbufBootstrapCallbacks>(m, "UbufBootstrapCallbacks")
.def(py::init<>(), py::call_guard<py::gil_scoped_release>())
.def(py::init<c10d::ProcessGroup *, c10d::ProcessGroup *>(),
py::call_guard<py::gil_scoped_release>());
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P)
.value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P)
.value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS)
.value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P)
.value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P);
// Note: Can't release GIL in constructor since it may bootstrap
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, int, bool, int,
bool, ubuf::UbufBootstrapCallbacks &>(),
py::call_guard<py::gil_scoped_release>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>());
// Note: Can't release GIL in constructor since it may bootstrap
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, bool, bool, int,
bool, bool, bool, ubuf::UbufBootstrapCallbacks &>(),
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>());
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
.value("kFloat32", transformer_engine::DType::kFloat32)
.value("kFloat16", transformer_engine::DType::kFloat16)
.value("kBFloat16", transformer_engine::DType::kBFloat16)
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2);
py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors")
.value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
.value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
......@@ -329,41 +248,61 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3)
.value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3);
py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)
.value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI);
py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
py::class_<CommOverlapHelper>(m, "CommOverlapHelper")
.def(py::init<>(), py::call_guard<py::gil_scoped_release>())
.def(py::init<c10d::ProcessGroup *, std::optional<c10d::ProcessGroup *>,
std::optional<c10d::ProcessGroup *>>(),
py::call_guard<py::gil_scoped_release>(), py::arg("world_group"),
py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none());
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
.value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D)
.value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD)
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
.value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
py::class_<CommOverlap>(m, "CommOverlap")
.def(py::init<const std::vector<size_t> &, at::ScalarType, CommOverlapHelper *, int, int, int,
int, int, bool, bool>(),
py::call_guard<py::gil_scoped_release>(), py::arg("buffer_shape"),
py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"),
py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS,
py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16,
py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false)
.def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs", &CommOverlap::split_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &CommOverlap::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard<py::gil_scoped_release>());
py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend")
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
py::class_<CommOverlapP2P>(m, "CommOverlapP2P")
.def(py::init<const std::vector<size_t> &, at::ScalarType, CommOverlapHelper *, int,
transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(),
py::call_guard<py::gil_scoped_release>(), py::arg("buffer_shape"),
py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"),
py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1,
py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false,
py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false)
.def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs,
py::call_guard<py::gil_scoped_release>())
.def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf,
py::call_guard<py::gil_scoped_release>())
.def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output,
py::call_guard<py::gil_scoped_release>())
.def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv,
py::call_guard<py::gil_scoped_release>())
.def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard<py::gil_scoped_release>())
.def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm,
py::call_guard<py::gil_scoped_release>())
.def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap,
py::call_guard<py::gil_scoped_release>());
}
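
The keyword defaults registered above mirror the C++ declarations. A hedged construction sketch using those bindings; the buffer shape, tensor-parallel size and bootstrap group are illustrative, and the import path is an assumption.

```python
# Hedged sketch of constructing the bound overlap objects with the defaults registered above.
import torch
import torch.distributed as dist
import transformer_engine.pytorch.cpp_extensions as tex  # assumed import path

helper = tex.CommOverlapHelper(dist.new_group(backend="gloo"))  # see CommOverlapHelper above
buffer_shape = [2048 * 2, 4096]  # illustrative (sequence*batch, hidden)

ub_bulk = tex.CommOverlap(buffer_shape, torch.bfloat16, helper, tp_size=8)
ub_p2p = tex.CommOverlapP2P(
    buffer_shape, torch.bfloat16, helper, 8,
    tex.CommOverlapType.AG,   # ring-exchange all-gather overlap; RS is also valid
    set_sm_margin=True,       # the P2P binding above defaults this to False
)
```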
......@@ -87,9 +87,55 @@ def initialize_ub(
ub_cfgs: Optional[dict] = None,
bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
"""Initialize communicators for TP comm overlap using userbuffers."""
r"""
Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.
Parameters
----------
shape : list
shape of the communication buffer, typically set to be the same as the global shape of
the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
tp_size : int
number of GPUs in the tensor-parallel process group
use_fp8 : bool = False
allocate the communication buffer for FP8 GEMM inputs/outputs
dtype : torch.dtype = torch.bfloat16
non-FP8 data type of the communication buffer when `use_fp8 = False`
ub_cfgs: dict = None
Configuration dictionary with the structure
```
{
<gemm_name> : {
"method": <"ring_exchange" or "pipeline">,
"is_reduce_scatter": bool,
"num_sm": int,
"cga_size": int,
"set_sm_margin": bool,
"num_splits": int,
"aggregate": bool,
"atomic_gemm": bool,
"use_ce": bool,
"fp8_buf": bool,
}
}
```
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc1_wgrad",
"fc2_fprop", "fc2_dgrad"]`.
bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are
valid for every cluster configuration and distributed launch method even if
they are available in PyTorch. When left unset, the initialization prefers
to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
option and always initializes Userbuffers with direct MPI calls in C++,
which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.
"""
if not tex.device_supports_multicast():
assert bool(os.getenv("UB_SKIPMC", "0")), (
assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
"CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
+ "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
)
......@@ -99,50 +145,52 @@ def initialize_ub(
_ub_communicators = {}
if tex.ubuf_built_with_mpi():
# Userbuffers will ignore all these values when it is built with MPI, so these are just
# placeholders based on an assumption that tp_size covers all devices in a physical node.
# We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
# an MPI_Init() here by creating a new MPI process group...
assert torch.distributed.is_mpi_available()
mpi_group = torch.distributed.new_group(backend="mpi")
world_rank = torch.distributed.get_rank(mpi_group)
world_size = torch.distributed.get_world_size(mpi_group)
local_rank = world_rank % tp_size
local_size = tp_size
self_node_idx = world_rank // tp_size
num_nodes = world_size // tp_size
ub_callbacks = tex.UbufBootstrapCallbacks()
_ = torch.distributed.new_group(backend="mpi")
helper = tex.CommOverlapHelper()
else:
# Bootstrapping with torch.distributed API, so check backend and construct
# intra/inter-node process groups...
assert (
torch.distributed.is_initialized()
), "torch.distributed must be initialized before Userbuffers"
if bootstrap_backend is None:
bootstrap_backend = "nccl"
if torch.distributed.is_gloo_available():
bootstrap_backend = "gloo"
elif torch.distributed.is_mpi_available():
if torch.distributed.is_mpi_available():
bootstrap_backend = "mpi"
elif torch.distributed.is_gloo_available():
bootstrap_backend = "gloo"
else:
assert bootstrap_backend in ["gloo", "mpi", "nccl"]
assert bootstrap_backend in [
"gloo",
"mpi",
"nccl",
], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
assert torch.distributed.is_backend_available(bootstrap_backend), (
f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
)
world_group = torch.distributed.new_group(backend=bootstrap_backend)
world_rank = torch.distributed.get_rank(world_group)
world_size = torch.distributed.get_world_size(world_group)
# Construct an intra-node communicator based on global ranks that share the same hostname
# NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host
# address on that interface instead of the hostname. This can help avoid issues when
# different hosts have the same hostname on Kubernetes clusters.
hostname = socket.gethostname()
# We have single-node NVLink so we can color based on physical node hostnames.
# NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and
# otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on
# the chosen bootstrap backend.
mydomain = socket.gethostname()
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
"NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME")
)
if ifname is not None:
# Make sure the ifname found in the environment is a valid network interface
if ifname in [name for _, name in socket.if_nameindex()]:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
mydomain = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
......@@ -154,57 +202,63 @@ def initialize_ub(
else:
ifname_warning = (
f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will"
" attempt to "
+ "detect ranks on the same node by matching 'socket.gethostname()', which is "
+ "known to fail on virtual clusters like Kubernetes. If Userbuffers "
+ "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in "
+ "your environment to the correct network interface."
+ " attempt to detect ranks on the same node by matching "
+ "'socket.gethostname()', which is known to fail on virtual clusters like "
+ "Kubernetes. If Userbuffers initialization fails, please set the "
+ "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network "
+ "interface."
)
warnings.warn(ifname_warning, UserWarning)
hostnames = [None for _ in range(world_size)]
torch.distributed.all_gather_object(hostnames, hostname, world_group)
unique_hosts = []
for host in hostnames:
if host not in unique_hosts:
unique_hosts.append(host)
num_nodes = len(unique_hosts)
if num_nodes > 1:
ranks_per_node_list = [[] for _ in range(num_nodes)]
self_node_idx = -1
for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname:
self_node_idx = node_idx
assert self_node_idx >= 0, "Internal TE error!"
intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration(
ranks_per_node_list, backend=bootstrap_backend
# Allgather the domain colors across ranks and reduce to a list of unique domains
domain_per_rank_list = [None for _ in range(world_size)]
torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group)
unique_domains = []
for domain in domain_per_rank_list:
if domain not in unique_domains:
unique_domains.append(domain)
num_domains = len(unique_domains)
if num_domains > 1:
# DP/TP model replicated on multiple NVLink domains
ranks_per_domain_list = [[] for _ in range(num_domains)]
mydomain_idx = -1
for i, domain in enumerate(domain_per_rank_list):
domain_idx = unique_domains.index(domain)
ranks_per_domain_list[domain_idx].append(i)
if domain == mydomain:
mydomain_idx = domain_idx
assert mydomain_idx >= 0, "Internal TE error!"
intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
ranks_per_domain_list, backend=bootstrap_backend
)
local_rank = torch.distributed.get_rank(intra_domain_group)
inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
[list(ranks) for ranks in zip(*ranks_per_domain_list)],
backend=bootstrap_backend,
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group)
helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group)
else:
self_node_idx = 0
intra_node_group = world_group
# TP model on single NVLink domain, no replication, no data-parallelism
mydomain_idx = 0
local_rank = world_rank
local_size = world_size
intra_node_ranks = list(range(world_size))
intra_domain_ranks = list(range(world_size))
helper = tex.CommOverlapHelper(world_group)
if world_rank == 0:
print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True)
print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True)
if local_rank == 0:
print(
f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n",
f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n",
end="",
flush=True,
)
ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group)
# Increase the workspace by the number of maximum concurrent streams
global _cublas_workspace
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
......@@ -303,46 +357,34 @@ def initialize_ub(
if atomic_gemm and method == "ring_exchange":
assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message
sample_buffer = torch.empty(
shape, dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype, device="cuda"
)
buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
if method == "ring_exchange":
ub_obj = tex.UbufP2PCommOverlap(
sample_buffer, # Sample userbuffer
world_rank, # World rank
world_size, # World size
local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node
self_node_idx, # Node ID
num_nodes, # Number of nodes
ub_obj = tex.CommOverlapP2P(
shape, # Communication buffer shape
buffer_dtype, # Communication buffer data type
helper, # Helper for torch.distributed callbacks during bootstrapping
tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
set_sm_margin, # Set SM margin
aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
is_reduce_scatter, # Overlap with reduce scatter
atomic_gemm, # Use a single GEMM with atomic-counters
use_ce, # Use copy engine for P2P communications
ub_callbacks,
tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
num_max_streams=_NUM_MAX_UB_STREAMS,
comm_cga_size=cga_size,
num_comm_sm=num_sm,
set_sm_margin=set_sm_margin,
atomic_gemm=atomic_gemm,
use_ce=use_ce,
aggregate=aggregate,
)
else:
ub_obj = tex.UbufCommOverlap(
sample_buffer, # Sample userbuffer
world_rank, # World rank
world_size, # World size
local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node
self_node_idx, # Node ID
num_nodes, # Number of nodes
ub_obj = tex.CommOverlap(
shape, # Communication buffer shape
buffer_dtype, # Communication buffer data type
helper, # Helper for torch.distributed callbacks during bootstrapping
tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
num_splits, # Number of communication splits
set_sm_margin, # Set SM margin
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
atomic_gemm, # Use a single GEMM with atomic-counters
ub_callbacks,
num_splits=num_splits,
num_max_streams=_NUM_MAX_UB_STREAMS,
comm_cga_size=cga_size,
num_comm_sm=num_sm,
set_sm_margin=set_sm_margin,
atomic_gemm=atomic_gemm,
)
_ub_communicators[name] = ub_obj
......
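
Putting the refactored bootstrapping together, a hedged usage sketch for te.initialize_ub() as documented in the docstring above; the sizes and the single ub_cfgs override are illustrative, not prescriptive.

```python
# Hedged usage sketch for the refactored te.initialize_ub(); values are illustrative.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")  # tensor/data-parallel job launched with torchrun or similar
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

seq_len, batch_size, hidden_size, tp_size = 2048, 2, 4096, 8
te.initialize_ub(
    [seq_len * batch_size, hidden_size],  # buffer shape: collapsed (sequence, batch) x hidden
    tp_size,
    use_fp8=False,
    dtype=torch.bfloat16,
    ub_cfgs={"proj_fprop": {"method": "pipeline", "num_sm": 16}},  # optional per-GEMM overrides
    bootstrap_backend="gloo",  # when left unset, the MPI > Gloo > NCCL priority applies
)
```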
......@@ -161,9 +161,9 @@ class _LayerNormLinear(torch.autograd.Function):
if not return_layernorm_output:
ln_out = torch.empty_like(ln_out)
if ub_obj_lnout.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
elif parallel_mode == "column" and sequence_parallel:
ln_out_gathered = True
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
......@@ -293,7 +293,7 @@ class _LayerNormLinear(torch.autograd.Function):
get_workspace(),
bias=bias,
use_bias=use_bias,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
......@@ -485,7 +485,7 @@ class _LayerNormLinear(torch.autograd.Function):
rs_out = None
if ctx.ub_bulk_dgrad:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG
ub_obj = ub_obj_lnout
elif ctx.ub_overlap_rs_dgrad:
dim_size = list(grad_output.size())
......@@ -496,14 +496,14 @@ class _LayerNormLinear(torch.autograd.Function):
)
if ub_obj_dgrad.is_p2p_overlap():
if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
ub_obj = ub_obj_dgrad
else:
ub_algo = None
......@@ -616,7 +616,7 @@ class _LayerNormLinear(torch.autograd.Function):
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
ub_algo=(
tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
),
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor,
......@@ -640,7 +640,7 @@ class _LayerNormLinear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=(
tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None
),
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor,
......@@ -658,7 +658,7 @@ class _LayerNormLinear(torch.autograd.Function):
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
)
clear_tensor_data(ln_out_total)
......
This diff is collapsed.