Unverified Commit 933294dc authored by Alp Dener, committed by GitHub

[C/PyTorch] Userbuffers and comm+GEMM overlap algorithms refactored and moved to TE/common (#1067)



* moved userbuffers code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved comm+GEMM overlap code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed PyTorch dependency from comm+GEMM overlap in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated TE/PyTorch Python API to match the refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated unit tests to work with refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* added a pylint exception to comm+GEMM overlap test runner
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fixing linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* added documentation for te.initialize_ub
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fixed compile errors when building with NVTE_UB_WITH_MPI=1
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fixed default bootstrap backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* switched default bootstrap backend priority to MPI > Gloo > NCCL
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* updated bootstrap backend documentation
Signed-off-by: Alp Dener <adener@nvidia.com>

* close UB bootstrap socket to avoid interfering with CUDA Multicast shareable file handle send/recv
Signed-off-by: Alp Dener <adener@nvidia.com>

* added torch::Tensor wrappers for communication buffer and atomic counters so PyTorch can factor externally allocated memory into its garbage collection threshold
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* automated handling of world, local and node ranks/sizes within C++ CommOverlapHelper to simplify Python function signatures
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fixed incorrect read of environment variables
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected priority for _SOCKET_IFNAME environment variables in UB bootstrapping
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved multicast support check to cuda_runtime.h and replaced cudaDeviceGetProp call with cached sm_count()
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* removed commented-out old code and replaced external collective function type defines with aliases
Signed-off-by: Alp Dener <adener@nvidia.com>

* compile-time CUDA version guard for CUDA Driver Multicast attribute
Signed-off-by: Alp Dener <adener@nvidia.com>

* added compile-time CUDA version guards to Multicast code in Userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* condensed UB docs, corrected const violations
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fixed autodoc rst for UB calls, added CUDA version guard on Multicast UB kernels
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect UB type reporting for P2P overlaps, comment reformatting
Signed-off-by: Alp Dener <adener@nvidia.com>

* added docstring to tex.ubuf_built_with_mpi()
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 35bbe740
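
For reference, a minimal sketch of how the refactored comm+GEMM overlap is set up from the TE/PyTorch side. te.initialize_ub and te.destroy_ub are the entry points documented by this PR; the keyword names shown here (use_fp8, dtype, bootstrap_backend) and the buffer-shape convention are assumptions based on the commit messages above, not a verified API reference:

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# One process per GPU; torch.distributed must be initialized before the
# userbuffers can be bootstrapped (default backend priority after this PR:
# MPI > Gloo > NCCL, depending on what is available).
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

tp_size = dist.get_world_size()
seq_len, batch_size, hidden_size = 2048, 2, 4096

# Allocate the shared communication buffers once, before building any modules
# that overlap tensor-parallel communication with GEMMs.
te.initialize_ub(
    [seq_len * batch_size, hidden_size],  # flattened (tokens, hidden) sample shape (assumed)
    tp_size,
    use_fp8=False,                  # assumed keyword
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",       # assumed keyword added by this PR
)

# ... construct and run tensor-parallel TE modules (e.g. te.Linear) here ...

te.destroy_ub()  # release the userbuffers before tearing down the process group
dist.destroy_process_group()
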
......@@ -38,3 +38,4 @@ dist/
downloads/
.pytest_cache/
compile_commands.json
.nfs
......@@ -11,7 +11,6 @@ import setuptools
from .utils import (
all_files_in_dir,
cuda_archs,
cuda_path,
cuda_version,
)
......@@ -29,9 +28,6 @@ def setup_pytorch_extension(
sources = [
csrc_source_files / "common.cu",
csrc_source_files / "ts_fp8_op.cpp",
csrc_source_files / "userbuffers" / "ipcsocket.cc",
csrc_source_files / "userbuffers" / "userbuffers.cu",
csrc_source_files / "userbuffers" / "userbuffers-host.cpp",
] + all_files_in_dir(extensions_dir)
# Header files
......@@ -85,19 +81,14 @@ def setup_pytorch_extension(
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
# Libraries
library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
mpi_home = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_home / "include")
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs.append(mpi_home / "lib")
libraries.append("mpi")
# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
......@@ -112,6 +103,4 @@ def setup_pytorch_extension(
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
......@@ -51,3 +51,7 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.moe_permute
.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
.. autoapifunction:: transformer_engine.pytorch.initialize_ub
.. autoapifunction:: transformer_engine.pytorch.destroy_ub
......@@ -57,13 +57,20 @@ class TimedBdist(bdist_wheel):
def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library"""
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())]
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")
# Project directory root
root_path = Path(__file__).resolve().parent
return CMakeExtension(
name="transformer_engine",
cmake_path=root_path / Path("transformer_engine/common"),
cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())],
cmake_flags=cmake_flags,
)
......
......@@ -22,6 +22,7 @@ import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.fp8 import _default_sf_compute
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
......@@ -32,8 +33,8 @@ torch_dtypes = {
}
nvte_comm_types = {
"rs": 0,
"ag": 1,
"rs": tex.CommOverlapType.RS,
"ag": tex.CommOverlapType.AG,
}
......@@ -75,7 +76,7 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--comm-type",
type=partial(_mapped_argtype, typemap=nvte_comm_types),
default=0,
default=tex.CommOverlapType.AG,
help="Comm type to overlap.",
)
parser.add_argument(
......@@ -156,11 +157,9 @@ def _parse_args(argv=None, namespace=None):
if opts.fp8:
warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
opts.fp8 = False
elif opts.comm_type == 1:
elif opts.comm_type == tex.CommOverlapType.AG:
if opts.atomic:
setattr(opts, "atomic_rs_p2p", opts.p2p)
if not opts.p2p:
warnings.warn("All-gather overlap is only supported with point-2-point comms.")
opts.p2p = True
if opts.atomic:
......@@ -283,35 +282,35 @@ def _main(opts):
if WORLD_RANK == 0:
print("\n", end="", flush=True)
ub_callbacks = (
tex.UbufBootstrapCallbacks()
helper = (
tex.CommOverlapHelper()
if tex.ubuf_built_with_mpi()
else tex.UbufBootstrapCallbacks(bootstrap_pg, bootstrap_pg)
else tex.CommOverlapHelper(bootstrap_pg)
)
if opts.comm_type == 0:
if opts.comm_type == tex.CommOverlapType.RS:
if opts.bulk_overlap:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_RS
ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS
elif opts.p2p:
ub_algo = (
tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
if opts.atomic
else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
)
else:
ub_algo = (
tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
tex.CommOverlapAlgo.ATOMIC_GEMM_RS
if opts.atomic
else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
)
elif opts.comm_type == 1:
elif opts.comm_type == tex.CommOverlapType.AG:
if opts.bulk_overlap:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG
else:
ub_algo = (
tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
if opts.atomic
else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
)
else:
raise TypeError("Invalid comm+GEMM overlap type!")
......@@ -322,95 +321,55 @@ def _main(opts):
hidden_size = opts.num_heads * opts.head_dim
inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
outer_size = reduce(operator.mul, inp_shape[:-1], 1)
ubuf_dtype = torch.bfloat16
if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output):
ubuf_dtype = torch.uint8
sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda")
ub_obj = ub_obj = (
tex.UbufP2PCommOverlap(
sample_buffer, # Sample userbuffer
WORLD_RANK, # World rank
WORLD_SIZE, # World size
LOCAL_RANK, # Rank within the node
LOCAL_SIZE, # Number of ranks/GPUs per node
0, # Node ID
1, # Number of nodes
buffer_dtype = torch.bfloat16
if (
opts.fp8
and not opts.bulk_overlap
and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output)
):
buffer_dtype = torch.uint8
ub_obj = (
tex.CommOverlapP2P(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
1, # Number of communication SMs
1, # CGA cluster size
opts.comm_type == 0 or opts.atomic, # Set SM margin
opts.aggregate, # Aggregate 2X GEMM chunks
3, # Max concurrent GEMM streams
opts.comm_type == 0, # overlap with reduce scatter
opts.atomic, # use a single GEMM with atomic-counters
not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
ub_callbacks,
opts.comm_type,
set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
atomic_gemm=opts.atomic,
aggregate=opts.aggregate,
use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
)
if opts.p2p
else tex.UbufCommOverlap(
sample_buffer, # Sample userbuffer
WORLD_RANK, # World rank
WORLD_SIZE, # World size
LOCAL_RANK, # Rank within the node
LOCAL_SIZE, # Number of ranks/GPUs per node
0, # Node ID
1, # Number of nodes
else tex.CommOverlap(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
16, # Number of communication SMs
2, # CGA cluster size
4, # Number of communication splits
True, # Set SM margin
3, # Max concurrent GEMM streams
opts.atomic, # Use a single GEMM with atomic-counters
ub_callbacks,
atomic_gemm=opts.atomic,
)
)
# Numerical check on AG + atomic GEMM requires testing an AG+RS pair
ub_obj2 = None
if opts.atomic and opts.comm_type == 1 and opts.check_numerics:
sample_buffer2 = torch.empty(
(outer_size, hidden_size),
dtype=torch.uint8 if opts.fp8_output else torch.bfloat16,
device="cuda",
)
if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics:
ub_obj2 = (
tex.UbufP2PCommOverlap(
sample_buffer2, # Sample userbuffer
WORLD_RANK, # World rank
WORLD_SIZE, # World size
LOCAL_RANK, # Rank within the node
LOCAL_SIZE, # Number of ranks/GPUs per node
0, # Node ID
1, # Number of nodes
tex.CommOverlapP2P(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
1, # Number of communication SMs
1, # CGA cluster size
True, # Set SM margin
False, # Aggregate 2X GEMM chunks
3, # Max concurrent GEMM streams
True, # overlap with reduce scatter
True, # use a single GEMM with atomic-counters
True, # use copy engine for P2P communications
ub_callbacks,
tex.CommOverlapType.RS,
set_sm_margin=True,
atomic_gemm=True,
)
if opts.atomic_rs_p2p
else tex.UbufCommOverlap(
sample_buffer2, # Sample userbuffer
WORLD_RANK, # World rank
WORLD_SIZE, # World size
LOCAL_RANK, # Rank within the node
LOCAL_SIZE, # Number of ranks/GPUs per node
0, # Node ID
1, # Number of nodes
else tex.CommOverlap(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
16, # Number of communication SMs
2, # CGA cluster size
4, # Number of communication splits
True, # Set SM margin
3, # Max concurrent GEMM streams
True, # Use a single GEMM with atomic-counters
ub_callbacks,
atomic_gemm=True,
)
)
......@@ -426,12 +385,12 @@ def _main(opts):
local_kernel_t_shape = (ffn_hidden_size, hidden_size)
local_inp_shape = (outer_size, hidden_size)
# Bulk overlap comm tensor is distributed for AG overlap only
if opts.comm_type == 1:
if opts.comm_type == tex.CommOverlapType.AG:
bulk_inp_shape = (outer_size // tp_size, hidden_size)
else:
bulk_inp_shape = (outer_size, hidden_size)
else:
if opts.comm_type == 1:
if opts.comm_type == tex.CommOverlapType.AG:
# (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P)
local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size)
local_inp_shape = (outer_size // tp_size, hidden_size)
......@@ -472,7 +431,7 @@ def _main(opts):
std=opts.std,
)
else:
if opts.comm_type == 1:
if opts.comm_type == tex.CommOverlapType.AG:
# AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K)
ker_g = torch.transpose(
te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1
......@@ -494,7 +453,7 @@ def _main(opts):
).to(dtype=torch.float32)
if opts.bulk_overlap:
if opts.comm_type == 1:
if opts.comm_type == tex.CommOverlapType.AG:
ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0]
else:
# First all-gather all the bulk inputs into a list
......@@ -505,7 +464,7 @@ def _main(opts):
else:
ref_g = torch.matmul(inp_g, ker_g)
if ub_obj2 is not None:
inp2_g = torch.nn.functional.gelu(ref_g)
inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable
ref2_g = torch.matmul(inp2_g, ker2_g)
if opts.fp8:
......@@ -529,7 +488,7 @@ def _main(opts):
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax)
ref_amax = torch.max(torch.abs(ref_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax)
if opts.bulk_overlap and opts.comm_type == 0:
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_amax = torch.max(torch.abs(bulk_inp))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax)
elif ub_obj2 is not None:
......@@ -551,7 +510,7 @@ def _main(opts):
kernel_t_fp8 = tex.cast_to_fp8(
kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype
)
if opts.bulk_overlap and opts.comm_type == 0:
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_fp8 = tex.cast_to_fp8(
bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype
)
......@@ -574,7 +533,7 @@ def _main(opts):
rtol=0.125,
atol=0.0675,
)
if opts.bulk_overlap and opts.comm_type == 0:
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
torch.allclose(
bulk_inp.to(dtype=torch.float32),
bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT],
......@@ -590,7 +549,7 @@ def _main(opts):
)
# Set Fp8 scales for userbuffers
if opts.comm_type == 1:
if opts.comm_type == tex.CommOverlapType.AG:
ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT])
if ub_obj2 is not None:
ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT])
......@@ -602,7 +561,7 @@ def _main(opts):
# Set up comm/compute buffers
ubuf_out2 = None
rs_out2 = None
if opts.comm_type == 1:
if opts.comm_type == tex.CommOverlapType.AG:
if opts.bulk_overlap:
ub_obj.copy_input_to_ubuf(bulk_inp, 1)
gemm_inp = inp
......@@ -686,9 +645,9 @@ def _main(opts):
gelu=False,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub_algo=(
tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
if opts.atomic_rs_p2p
else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
else tex.CommOverlapAlgo.ATOMIC_GEMM_RS
),
ub=ub_obj2,
extra_output_tensor=rs_out2,
......@@ -762,10 +721,14 @@ def _main(opts):
avg_gpu_time = sum(gpu_times) / opts.timing_iters
gemm_name = "".join(
[
"p2p all-gather + " if opts.comm_type == 1 else "",
"p2p all-gather + " if opts.comm_type == tex.CommOverlapType.AG else "",
"atomic " if opts.atomic else "",
"GEMM",
(f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == 0 else ""),
(
f" + {'p2p ' if opts.p2p else ''}reduce-scatter"
if opts.comm_type == tex.CommOverlapType.RS
else ""
),
]
)
timing_info = (
......@@ -781,7 +744,7 @@ def _main(opts):
dist.barrier(tp_group)
if opts.bulk_overlap:
output_info = ""
if opts.comm_type == 1:
if opts.comm_type == tex.CommOverlapType.AG:
# Bulk overlap AG output is already gathered
test_out = ub_obj.get_ubuf_output(1)
else:
......@@ -794,7 +757,7 @@ def _main(opts):
output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}"
dist_print(
output_info,
src=0 if opts.comm_type == 0 else None,
src=0 if opts.comm_type == tex.CommOverlapType.RS else None,
section=True,
)
......@@ -805,7 +768,7 @@ def _main(opts):
)
dist_print(nonzero_info, src=0, section=True, group=tp_group)
else:
if opts.comm_type == 1:
if opts.comm_type == tex.CommOverlapType.AG:
if ub_obj2 is not None:
# AG+RS Output: (M/P, N) -> gather -> (M, N)
output = rs_out2.to(dtype=torch.float32)
......
......@@ -9,7 +9,6 @@ import sys
import socket
import argparse
import warnings
from functools import partial
import torch
import torch.distributed as dist
......
......@@ -42,6 +42,9 @@ if not tex.device_supports_multicast():
# Force GPU kernels to launch in the order they're executed by the host CPU
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
# Clear torch.dynamo caches
torch._dynamo.reset()
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate):
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
......
......@@ -80,7 +80,11 @@ list(APPEND transformer_engine_SOURCES
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu
recipe/delayed_scaling.cu)
recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
......@@ -93,6 +97,15 @@ target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
if (NVTE_UB_WITH_MPI)
find_package(MPI REQUIRED)
target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX)
target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES})
target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()
# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <cassert>
#include <numeric>
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "userbuffers/userbuffers.h"
#define HALF_BYTES 2
#define UB_MAX_SM 32
using namespace std::placeholders;
namespace transformer_engine {
/***************************************************************************************************
* Comm+GEMM Overlap Common Core
**************************************************************************************************/
bool ubuf_built_with_mpi() {
#ifdef NVTE_UB_WITH_MPI
return true;
#else
return false;
#endif
}
CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm) {
// Initialize userbuf communicator
if (!_comm_created) {
if (myrank == 0) {
printf("!!! [UB] Create Userbuffers Communicator\n");
}
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
#else
create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
allgather_handle, barrier_handle, 1, 1, tp_size, 1);
#endif
_comm_created = true;
}
_use_ce = static_cast<int>(use_ce);
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1));
_stream_compute.push_back(std::move(stream));
}
_num_splits = num_splits;
_rank = _ub_comm->myrank;
_tp_size = tp_size;
_tp_id = _rank % _tp_size;
// Set the number of SMs for GEMM with margin
int sm_count = transformer_engine::cuda::sm_count();
_math_sms = (set_sm_margin) ? sm_count - num_comm_sm : sm_count;
_math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);
_atomic_gemm = atomic_gemm;
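// Allocate 2 * num_splits int32 flags used to synchronize the atomic GEMM with the
// communication kernels. Note that cudaMemset writes bytes: the whole buffer is zeroed
// first, then every byte of the first half is set to 1, so the first-half words hold a
// nonzero pattern rather than the literal value 1.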
if (_atomic_gemm) {
void *counter_ptr;
size_t counter_bytes = _num_splits * 2 * sizeof(int32_t);
NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes));
NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes));
NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2));
_counter = TensorWrapper(counter_ptr, std::vector<size_t>{static_cast<size_t>(_num_splits * 2)},
DType::kInt32);
}
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);
}
CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_stop_comm);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
if (_atomic_gemm) cudaFree(_counter.dptr());
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
if (_comm_created) {
#ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm);
#else
destroy_communicator(_ub_comm);
#endif
_comm_created = false;
}
}
/***************************************************************************************************
* Comm+GEMM Overlap Base (Pipelined / Collective)
**************************************************************************************************/
CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool atomic_gemm)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
num_comm_sm, set_sm_margin, false, atomic_gemm) {
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
"Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ",
"or 2 (multi-atomic).");
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
}
CommOverlapBase::~CommOverlapBase() {
cudaEventDestroy(_start_d2dcopy);
cudaStreamDestroy(_stream_comm);
}
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
} else {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
comm_elements *= 2;
assert(rs_output.numel() == _ubuf.numel() / _tp_size);
assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
comm_elements, _ub_comm, _stream_comm);
} else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
}
}
assert(pre_gelu_out.numel() == 0);
nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
stream_main);
_ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::bulk_overlap
/*
** Atomic FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions
size_t m = A.size(0);
size_t k = A.size(1);
size_t n = B.size(0);
size_t m_chunk = m / _num_splits;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
// Reset atomic counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _num_splits, false, stream_main);
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0));
assert(pre_gelu_out.numel() == 0);
auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr);
auto workspace_chunk =
TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(),
_stream_compute[0]);
for (int i = 0; i < _num_splits; i++) {
if (_rs_kernel_type == 1) {
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
&counter_ptr[i], _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_num_splits, &counter_ptr[i], _ub_comm,
_stream_comm);
}
} else if (_rs_kernel_type == 2) {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m,
_num_splits, counter_ptr, _ub_comm,
_stream_comm);
}
break;
} else {
consumer(counter_ptr, i, _stream_comm);
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(rs_output_ptr, _ubuf_scale_inv,
_ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, _stream_comm);
}
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
_ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::atomic_gemm_overlap_rs
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main) {
// Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
size_t m = A.size(0);
size_t k = A.size(1);
size_t n = B.size(0);
size_t m_chunk = m / _num_splits;
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0));
assert(pre_gelu_out.numel() == 0);
if (gemm_overlap) {
auto input_a_chunk =
TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk =
TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
auto workspace_chunk = TensorWrapper(
workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * D.element_size();
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
input_a_chunk = TensorWrapper(reinterpret_cast<void *>(input_a_chunk_ptr), {m_chunk, k},
A.dtype(), nullptr, nullptr, A.scale_inv());
output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr), {n, m_chunk},
D.dtype(), D.amax(), D.scale(), nullptr);
workspace_chunk = TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Communication chunk
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm);
}
} else {
for (int i = 0; i < _num_splits; i++) {
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto input_a_chunk = TensorWrapper(reinterpret_cast<void *>(input_a_chunk_ptr), {m_chunk, k},
A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr),
{n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::split_overlap_rs
/***************************************************************************************************
* Comm+GEMM Overlap P2P Base (Ring-Exchange)
**************************************************************************************************/
CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm, bool aggregate)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
num_comm_sm, set_sm_margin, use_ce, atomic_gemm) {
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / tp_size;
_num_ubuf_chunks = tp_size;
if (_is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
// outputs for reduction at the end of the pipelining.
buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1);
_num_ubuf_chunks = tp_size * 2 - 1;
}
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
{buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
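// Compute the ring-exchange neighbors for this rank within its tensor-parallel group
// (ranks are assumed to be laid out consecutively in blocks of _tp_size).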
_rank_round_tp = (_rank / _tp_size) * _tp_size;
_next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp;
_prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp;
_self_chunk_id = _tp_id;
if (_atomic_gemm && !_is_reduce_scatter) {
_use_multiatomic_ag = getenv<bool>("NVTE_AG_P2P_MULTI_ATOMIC");
if (_use_multiatomic_ag) {
_use_ce = 0;
_ub_comm->push = 1;
if (_rank == 0) {
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
}
}
_self_chunk_id = 0;
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
}
CommOverlapP2PBase::~CommOverlapP2PBase() {
cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send);
cudaStreamDestroy(_stream_recv);
cudaStreamDestroy(_stream_send);
}
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes input_b is pre-copied to _ubufs[rank_id]. This is needed so that the AG
** output on each rank ends up in contiguous memory after all ring-exchange phases.
*/
void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const size_t m = (transa) ? A.size(0) : A.size(1);
const size_t n = _ubuf.size(0);
const size_t n_chunk = n / _tp_size;
assert(pre_gelu_out.numel() == 0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Create a GEMM output buffer with N+1 chunks in contiguous memory
void *D_buffer_ptr;
int D_chunk_bytes = n_chunk * m * D.element_size();
NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main));
auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr);
// Reset atomic counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _tp_size, true, stream_main);
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv());
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk =
TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. The buffer being sent is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubufs[rank] so that the
// AG output ends up contiguous on every rank after the ring exchanges.
int send_chunk_id = i;
int recv_chunk_id = i + 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
if (_use_multiatomic_ag) {
if (i == 0) {
_ub_comm->use_ce = 0;
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
true, _stream_recv);
}
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank,
_stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank,
_stream_recv);
producer(counter_ptr, recv_chunk_id, _stream_recv);
}
if (i == 0) {
nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false,
_counter.data(), stream_main);
}
}
// Store the input activation for backprop
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes,
cudaMemcpyDeviceToDevice, stream_main));
// Return the last N rows of D_buffer
NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
cudaMemcpyDeviceToDevice, stream_main));
// Clean up buffer allocation
NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main));
_ub_comm->sms = ori_sms;
} // CommOverlapP2PBase::atomic_gemm_overlap_ag
/*
** Split AllGather + GEMM using P2P communication
** This function assumes input_b is pre-copied to _ubufs[rank_id]. This is needed so that the AG
** output on each rank ends up in contiguous memory after all ring-exchange phases.
*/
void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const size_t m = (transa) ? A.size(0) : A.size(1);
const size_t k = (transa) ? A.size(1) : A.size(0);
const size_t n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
const int output_chunk_bytes = (n_chunk * m) * D.element_size();
const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.dptr());
char *pre_gelu_out_ptr = reinterpret_cast<char *>(pre_gelu_out.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
if (_aggregate) {
const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.dptr());
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp;
// Ring exchange of 2X inputs chunks
for (int i = 0; i < num_steps; i++) {
send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size;
recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size;
send_offset = comm_bytes * send_chunk_id;
recv_offset = comm_bytes * recv_chunk_id;
// GEMM
char *input_b_chunk_ptr = input_b_ptr + send_offset;
auto input_b_chunk =
TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(),
nullptr, nullptr, B.scale_inv());
char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes);
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr),
{n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr);
char *aux_chunk_ptr =
(do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
auto aux_chunk_shape =
(do_gelu) ? std::vector<size_t>{n_chunk * 2, m} : std::vector<size_t>{0};
auto aux_chunk = TensorWrapper(reinterpret_cast<void *>(aux_chunk_ptr), aux_chunk_shape,
pre_gelu_out.dtype());
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
if (i < num_steps - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, _stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
}
}
} else {
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. The buffer being sent is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubufs[rank] so that the
// AG output ends up contiguous on every rank after the ring exchanges.
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(),
nullptr, nullptr, B.scale_inv());
char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes);
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr), {n_chunk, m},
D.dtype(), D.amax(), D.scale(), nullptr);
char *aux_chunk_ptr =
(do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
auto aux_chunk_shape = (do_gelu) ? std::vector<size_t>{n_chunk, m} : std::vector<size_t>{0};
auto aux_chunk = TensorWrapper(reinterpret_cast<void *>(aux_chunk_ptr), aux_chunk_shape,
pre_gelu_out.dtype());
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
if (i < _tp_size - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, _stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
}
}
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
} // CommOverlapP2PBase::split_overlap_ag
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Reset counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _tp_size, false, stream_main);
// Catch up the main stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr);
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk =
TensorWrapper(workspace.data(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(),
stream_main);
// P2P communication chunk
for (int i = 1; i < _tp_size; i++) {
int send_chunk_id = i - 1;
int recv_chunk_id = send_chunk_id + _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
consumer(counter_ptr, send_chunk_id, _stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
_stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
_stream_recv);
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size,
_ubufs[0].numel(), stream_main););
} else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
}
_ub_comm->sms = ori_sms;
}
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
size_t k = A.size(1);
size_t n = B.size(0);
// Get communication and GEMM input chunk sizes
size_t n_chunk = n / _tp_size;
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int input_b_chunk_bytes = n_chunk * k * B.element_size();
// Get input and workspace data pointers
char *input_b_ptr = reinterpret_cast<char *>(B.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Catch up the main stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
// GEMM and send/recv chunks
for (int i = 0; i < _tp_size; i++) {
// GEMM chunk
int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes);
auto input_b_chunk = TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk, k},
B.dtype(), nullptr, nullptr, B.scale_inv());
auto output_chunk =
TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr);
char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]);
if (i > 0) {
// P2P communication chunk
int send_offset = comm_bytes * (i - 1);
int recv_offset = comm_bytes * (i - 1 + _tp_size);
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
_stream_send);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
_stream_recv);
}
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size,
_ubufs[0].numel(), stream_main););
} else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
}
_ub_comm->sms = ori_sms;
}
} // namespace transformer_engine
......@@ -20,7 +20,9 @@
#include <utility>
#include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "ipcsocket.h"
#include "userbuffers.h"
......@@ -44,31 +46,19 @@ static MPI_Comm EXT_COMM_INTER;
} while (false)
void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
ExtComm group) {
// UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE,
// globaldata, globalbytes, MPI_BYTE,
// static_cast<MPI_Comm>(group)));
MPI_Comm comm = static_cast<MPI_Comm>(group);
ExtComm comm) {
int numranks;
UB_MPI_CHECK(MPI_Comm_size(comm, &numranks));
assert(globalbytes == numranks * localbytes);
int myrank;
UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank));
char *globaltarget = reinterpret_cast<char *>(globaldata) + (myrank * localbytes);
memcpy(globaltarget, localdata, localbytes);
for (int n = 0; n < numranks; n++) {
globaltarget = reinterpret_cast<char *>(globaldata) + (n * localbytes);
UB_MPI_CHECK(MPI_Bcast(globaltarget, localbytes, MPI_BYTE, n, comm));
}
UB_MPI_CHECK(
MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm));
}
void ub_mpi_barrier(ExtComm group) { UB_MPI_CHECK(MPI_Barrier(static_cast<MPI_Comm>(group))); }
void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); }
#else
static char EXT_COMM_WORLD[] = "world";
static char EXT_COMM_INTRA[] = "intra";
static char EXT_COMM_INTER[] = "inter";
#define EXT_COMM_WORLD "world"
#define EXT_COMM_INTRA "intra"
#define EXT_COMM_INTER "inter"
#endif
#define MULTICAST_GB_TOTAL 512
......@@ -106,11 +96,10 @@ int pipe_rank(communicator *comm, int step) {
return newnode * numlocal + newlocal;
}
int create_communicator_grouped2(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes) {
int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
int numlocal, int mynode, int numnodes,
ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
int pipegpus, int pipenodes, int tensorgpus, int tensornodes) {
*comm = new communicator();
(*comm)->comm_world = EXT_COMM_WORLD;
......@@ -214,8 +203,11 @@ int create_communicator_grouped2(
(*comm)->asyncblocks = 16;
#define NBUF 2
if ((*comm)->sm_arch >= 9 && (*comm)->ar2_nvsize > 1 &&
!getenv("UB_SKIPMC")) { // multicast init only for TP ops (____2 operations)
#if CUDART_VERSION >= 12010
if (!transformer_engine::getenv<bool>("UB_SKIPMC") &&
transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) {
// multicast init only for TP ops (____2 operations)
size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30);
(*comm)->mc_offset = 0;
(*comm)->use_mc = 1;
......@@ -291,20 +283,20 @@ int create_communicator_grouped2(
(*comm)->_barrier((*comm)->comm_world);
if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize);
} else {
#endif
if (!(*comm)->myrank) printf("MC NOT initialized and used\n");
(*comm)->mc_maxsize = 0;
(*comm)->mc_offset = 0;
(*comm)->use_mc = 0;
#if CUDART_VERSION >= 12010
}
#endif
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
NVTE_CHECK_CUDA(
cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet
NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false);
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true);
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
......@@ -346,18 +338,17 @@ int create_communicator_grouped2(
return 0;
}
int create_communicator_grouped(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes) {
int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal,
int numlocal, int mynode, int numnodes,
ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
int pipegpus, int pipenodes) {
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1);
}
int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes,
std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier) {
int mynode, int numnodes, ExtAllgatherOp ext_allgather,
ExtBarrierOp ext_barrier) {
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_allgather, ext_barrier, 1, 1, 1, 1);
}
......@@ -428,7 +419,7 @@ int create_communicator_mpi(communicator **comm) {
void destroy_communicator(communicator *comm) {
for (int hndl = 0; hndl < comm->free_region; hndl++) {
if (hndl > 0 && comm->use_mc && comm->mem_dealloc[hndl]) {
if (comm->use_mc && comm->mem_dealloc[hndl]) {
for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank == comm->nvrank) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);
......@@ -479,6 +470,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->memflags[hndl] = 0;
comm->mem_dealloc[hndl] = alloc;
#if CUDART_VERSION >= 12010
if (comm->use_mc && alloc) {
int nranks = comm->nvsize; // total GPUs in NVLINK domain
int myrank = comm->nvrank;
......@@ -594,6 +586,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
}
} else {
#endif
if (alloc) {
NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));
......@@ -624,7 +617,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
free(tmp);
#if CUDART_VERSION >= 12010
}
#endif
comm->mem_size[hndl] = aligned_size;
comm->mem_ptr[hndl] = *gpubuff;
......
......@@ -392,14 +392,14 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
} // fp16 reduce-scatter kernel (out of place)
#if __CUDA_ARCH__ >= 900
#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
// All MC kernels here
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep, const int lineoffset,
const int numlines, void **commbuff, const int handleridx,
float4 *mc_ptr) {
float4 *mc_ptr, const uint64_t ub_timeout) {
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
......@@ -417,7 +417,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
clock_t s = clock64();
while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > TIMEOUT) {
if (clock64() - s > ub_timeout) {
UB_PRINT("Reduce-scatter: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x,
reduce_id, *flag);
break;
......@@ -484,7 +484,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
while (CHECK_IDS(*flag, reduce_id)) {
if (clock64() - s > 2ull * TIMEOUT) {
if (clock64() - s > 2ull * ub_timeout) {
UB_PRINT("Allgather: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
......@@ -741,7 +741,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep, const int lineoffset,
const int numlines, void **commbuff, const int handleridx,
float4 *mc_ptr) {}
float4 *mc_ptr, const uint64_t ub_timeout) {}
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset,
......@@ -2496,6 +2496,18 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i
}
}
// reset counters kernel: sets the first num_chunks flags to 1 ("chunk not yet produced") and
// clears the next num_chunks entries; for all-gather overlaps the first flag is cleared up front
// since that chunk is already available locally.
static __global__ void reset_counters_kernel(void *atomic_ptr, int num_chunks, bool allgather) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < num_chunks; i++) {
((unsigned int *)atomic_ptr)[i] = 1;
((unsigned int *)atomic_ptr)[i + num_chunks] = 0;
}
if (allgather) ((unsigned int *)atomic_ptr)[0] = 0;
}
}
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
......@@ -2514,6 +2526,12 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr
consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
}
void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
}
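// Usage sketch (illustrative only, variable names below are assumptions): the atomic-GEMM overlap
// paths reset the counters on the main stream before launching the GEMM, then let the
// communication stream consume each chunk's flag as the GEMM marks it ready.
//
//   int *counters = reinterpret_cast<int *>(counter_buf.dptr());  // holds 2 * num_chunks ints
//   reset_counters(counters, tp_size, /*allgather=*/false, main_stream);
//   // ... nvte_cublas_atomic_gemm(..., counter_buf.data(), main_stream) ...
//   for (int chunk = 0; chunk < tp_size - 1; ++chunk) {
//     consumer(counters, chunk, comm_stream);  // block comm stream until the chunk is produced
//     // ... userbuffers_send / userbuffers_recv for this chunk ...
//   }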
template <typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
......@@ -2546,3 +2564,24 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output,
template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
__global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) {
  const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
  // The launcher rounds the grid up to whole blocks, so guard threads past the end of the chunk.
  if (tid >= static_cast<size_t>(input_size)) return;
  // Accumulate the BF16 partial chunks in FP32, then store a single BF16 result.
  __nv_bfloat16 *inputs_bf16 = reinterpret_cast<__nv_bfloat16 *>(inputs);
  float accum_buf = static_cast<float>(inputs_bf16[tid]);
#pragma unroll
  for (int i = 1; i < num_inputs; i++) {
    accum_buf += static_cast<float>(inputs_bf16[tid + input_size * i]);
  }
  __nv_bfloat16 *output_bf16 = reinterpret_cast<__nv_bfloat16 *>(output);
  output_bf16[tid] = static_cast<__nv_bfloat16>(accum_buf);
}
void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) {
size_t num_threads = MAX_THREADS / 4;
size_t num_blocks = (input_size + num_threads - 1) / num_threads;
dim3 block(num_threads);
dim3 grid(num_blocks);
reduce_bf16_cuda<<<grid, block, 0, stream>>>(inputs, output, num_inputs, input_size);
}
......@@ -19,11 +19,14 @@
#ifdef NVTE_UB_WITH_MPI
#include <mpi.h>
typedef MPI_Comm ExtComm;
#define ExtComm MPI_Comm
#else
typedef char *ExtComm;
#define ExtComm const char *
#endif
using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
using ExtBarrierOp = std::function<void(ExtComm)>;
#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_SMS 32
#define NVTE_MAX_OPS 32
......@@ -142,12 +145,12 @@ struct communicator {
volatile int tail;
// Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
std::function<void(void *, size_t, void *, size_t, ExtComm)> _allgather;
std::function<void(ExtComm)> _barrier;
ExtAllgatherOp _allgather;
ExtBarrierOp _barrier;
ExtComm comm_world,
comm_inter, // reduction group communicator (subset of the nodes) along GPU rail
comm_intra; // full intranode (all ndev GPUS)
ExtComm comm_world;
ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail
ExtComm comm_intra; // full intranode (all ndev GPUS)
#ifdef NVTE_UB_WITH_MPI
MPI_Request mpihndl[NVTE_MAX_SHARP];
#endif
......@@ -161,23 +164,22 @@ typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream);
void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream);
/* creates communicator, allocates all internal buffers if necessary */
int create_communicator_grouped2(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes);
int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
int numlocal, int mynode, int numnodes,
ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
int pipegpus, int pipenodes, int tensorgpus, int tensornodes);
int create_communicator_grouped(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes);
int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal,
int numlocal, int mynode, int numnodes,
ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
int pipegpus, int pipenodes);
int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes,
std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier);
int mynode, int numnodes, ExtAllgatherOp ext_allgather,
ExtBarrierOp ext_barrier);
int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
int tensorgpus, int tensornodes);
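// Illustrative bootstrap sketch (not part of this API): a framework built without NVTE_UB_WITH_MPI
// can pass its own collectives as ExtAllgatherOp / ExtBarrierOp callables. The torch_allgather and
// torch_barrier helpers below are hypothetical placeholders for framework-provided implementations.
//
//   ExtAllgatherOp allgather = [](void *global, size_t gbytes, void *local, size_t lbytes,
//                                 ExtComm comm) {
//     torch_allgather(global, gbytes, local, lbytes, comm);  // hypothetical helper
//   };
//   ExtBarrierOp barrier = [](ExtComm comm) { torch_barrier(comm); };  // hypothetical helper
//   communicator *ub_comm = nullptr;
//   create_communicator_grouped2(&ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
//                                allgather, barrier, /*pipegpus=*/1, /*pipenodes=*/1,
//                                /*tensorgpus=*/tp_size, /*tensornodes=*/1);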
......@@ -314,4 +316,6 @@ template <typename fp8type>
void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, int input_size,
cudaStream_t stream);
void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
#include <cuda.h>
#include <cuda_fp8.h>
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
namespace transformer_engine {
/* \brief Check if Userbuffers bootstraps with direct calls to MPI collectives.
 * This can be turned on by building Transformer Engine with the `NVTE_UB_WITH_MPI=1` option.
*
* \return True if Userbuffers is built with MPI
*/
bool ubuf_built_with_mpi();
enum class CommOverlapType { RS = 0, AG = 1 };
enum class CommOverlapAlgo {
BULK_OVERLAP_AG = 0,
BULK_OVERLAP_RS = 1,
SPLIT_PIPELINED_AG_P2P = 2,
SPLIT_PIPELINED_RS = 3,
SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7
};
class CommOverlapCore {
protected:
static inline communicator *_ub_comm{nullptr};
static inline bool _comm_created{false};
int _rank;
int _tp_id;
int _tp_size;
int _num_splits;
int _math_sms;
int _num_comm_sm;
int _cga_size;
int _use_ce;
int _ub_reg;
bool _atomic_gemm{false};
bool _is_p2p{false};
TensorWrapper _ubuf;
TensorWrapper _counter;
float *_ubuf_scale_inv;
bool _ubuf_scale_inv_initialized{false};
std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
public:
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm,
bool set_sm_margin, bool use_ce, bool atomic_gemm);
virtual ~CommOverlapCore();
void set_ubuf_scale_inv(float *scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; }
bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }
}; // CommOverlapCore
class CommOverlapBase : public CommOverlapCore {
protected:
int _rs_kernel_type;
cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy;
public:
CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);
virtual ~CommOverlapBase();
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main);
/*
** Split FPROP Atomic GEMM + ReduceScatter
*/
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap,
TensorWrapper &rs_output, cudaStream_t stream_main);
/*
** Split FPROP GEMM + ReduceScatter
*/
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main);
}; // CommOverlapBase
class CommOverlapP2PBase : public CommOverlapCore {
protected:
bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false};
int _next_rank;
int _prev_rank;
int _rank_round_tp;
int _aggregate;
int _num_ubuf_chunks;
int _self_chunk_id;
std::vector<TensorWrapper> _ubufs;
cudaStream_t _stream_send;
cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv;
public:
CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS,
int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false,
bool use_ce = true, bool atomic_gemm = false, bool aggregate = false);
virtual ~CommOverlapP2PBase();
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed so that the
** AG outputs on each rank end up in contiguous memory after all ring exchange phases.
*/
void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main);
/*
** Split AllGather + GEMM using P2P communication
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed so that the
** AG outputs on each rank end up in contiguous memory after all ring exchange phases.
*/
void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main);
/*
** Split ReduceScatter + Atomic GEMM using P2P communication
*/
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main);
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main);
}; // CommOverlapP2PBase
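/*
** Illustrative usage sketch (names and buffer sizes below are assumptions, not part of this API):
** a framework wrapper constructs the overlap object with its own bootstrap callbacks, pre-copies
** the communication input into the userbuffer, and routes the GEMM through the method matching
** the chosen CommOverlapAlgo, e.g.
**
**   CommOverlapP2PBase overlap({seq_chunk, hidden}, DType::kBFloat16, myrank, numranks, mylocal,
**                              numlocal, mynode, numnodes, tp_size, allgather_cb, barrier_cb,
**                              CommOverlapType::AG);
**   overlap.split_overlap_ag(A, true, B, false, D, bias, pre_gelu_out, workspace,
**                            false, false, false, B_copy, stream);
*/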
} // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
......@@ -78,13 +78,13 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType
*/
void nvte_destroy_tensor(NVTETensor tensor);
/*! \brief Get a tensor's data type.
/*! \brief Get a raw pointer to the tensor's data.
*
* \param[in] tensor Tensor.
*
* \return A data type of the input tensor.
* \return A raw pointer to tensor's data.
*/
NVTEDType nvte_tensor_type(const NVTETensor tensor);
void *nvte_tensor_data(const NVTETensor tensor);
/*! \brief Get a tensor's data shape.
*
......@@ -94,13 +94,46 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor);
*/
NVTEShape nvte_tensor_shape(const NVTETensor tensor);
/*! \brief Get a raw pointer to the tensor's data.
/*! \brief Get a tensor's number of dimensions.
*
* \param[in] tensor Tensor.
*
* \return A raw pointer to tensor's data.
* \return Number of tensor dimensions.
*/
void *nvte_tensor_data(const NVTETensor tensor);
size_t nvte_tensor_ndims(const NVTETensor tensor);
/*! \brief Get the size of a specific tensor dimension.
*
* \param[in] tensor Tensor.
 *  \param[in]     dim             Dimension index.
*
* \return Size of the tensor at the specified dimension.
*/
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim);
/*! \brief Get a tensor's total number of elements.
*
* \param[in] tensor Tensor.
*
* \return Number of elements in the tensor.
*/
size_t nvte_tensor_numel(const NVTETensor tensor);
/*! \brief Get the byte size for the tensor's data type.
*
* \param[in] tensor Tensor.
*
* \return Byte size of the tensor's data type.
*/
size_t nvte_tensor_element_size(const NVTETensor tensor);
/*! \brief Get a tensor's data type.
*
* \param[in] tensor Tensor.
*
* \return A data type of the input tensor.
*/
NVTEDType nvte_tensor_type(const NVTETensor tensor);
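/* Illustrative sketch (not part of this API): combining the accessors above to recover the full
 * shape of a tensor, e.g. for validation. `t` is an assumed NVTETensor handle.
 *
 *   size_t ndims = nvte_tensor_ndims(t);
 *   for (size_t d = 0; d < ndims; ++d) {
 *     size_t dim_size = nvte_tensor_size(t, d);
 *     // ... use dim_size ...
 *   }
 */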
/*! \brief Get a pointer to the tensor's amax data.
*
......@@ -265,6 +298,56 @@ class TensorWrapper {
return nvte_tensor_shape(tensor_);
}
/*! \brief Get the size of this TensorWrapper in the given dimension.
*
   *  \param[in]     dim             Dimension index.
*
* \return Size of this TensorWrapper in given dimension.
*/
size_t size(const size_t dim) const {
if (tensor_ == nullptr) return 0;
return nvte_tensor_size(tensor_, dim);
}
/*! \brief Get the number of dimensions for this TensorWrapper.
*
* \return Number of dimensions for this TensorWrapper.
*/
size_t ndim() const noexcept {
if (tensor_ == nullptr) return 0;
return nvte_tensor_ndims(tensor_);
}
/*! \brief Get the number of allocated elements in the tensor. This will return 0 for tensors
* with nullptr data even if the TensorWrapper has a non-zero shape.
*
*
* \return Number of elements in the tensor.
*/
size_t numel() const noexcept {
if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
return nvte_tensor_numel(tensor_);
}
/*! \brief Get the tensor's element size in bytes.
*
* \return Element size in bytes.
*/
size_t element_size() const noexcept {
if (tensor_ == nullptr) return 0;
return nvte_tensor_element_size(tensor_);
}
/*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr
* data even if the TensorWrapper has a non-zero shape and valid dtype.
*
* \return Total tensor size in bytes.
*/
size_t bytes() const noexcept {
if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_);
}
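  /* Illustrative note: the comm+GEMM overlap code sizes its per-chunk transfers from these
   * accessors, e.g. (a sketch with an assumed chunk wrapper named ubuf_chunk):
   *
   *   const size_t comm_bytes = ubuf_chunk.numel() * ubuf_chunk.element_size();
   *   // equivalent to ubuf_chunk.bytes() when the chunk has allocated data
   */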
/*! \brief Get the data type of this TensorWrapper.
*
* \return Data type of this TensorWrapper.
......@@ -317,6 +400,6 @@ class TensorWrapper {
} // namespace transformer_engine
#endif
#endif // __cplusplus
#endif // TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
......@@ -93,6 +93,31 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
return ret;
}
size_t nvte_tensor_ndims(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.shape.size();
}
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  NVTE_CHECK(dim < t.data.shape.size(), "Invalid dimension index: ", dim);
return t.data.shape[dim];
}
size_t nvte_tensor_numel(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
size_t numel = 1;
for (auto size : t.data.shape) {
numel *= size;
}
return numel;
}
size_t nvte_tensor_element_size(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return transformer_engine::typeToSize(t.data.dtype);
}
void *nvte_tensor_data(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.dptr;
......
......@@ -12,6 +12,7 @@
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/system.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
......@@ -80,6 +81,31 @@ int sm_count(int device_id) {
return cache[device_id];
}
bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
// NOTE: This needs to be guarded at compile time because the
// CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
static std::vector<bool> cache(num_devices(), false);
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&]() {
CUdevice cudev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
int result;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
cache[device_id] = static_cast<bool>(result);
};
std::call_once(flags[device_id], init);
return cache[device_id];
#else
return false;
#endif
}
const std::string &include_directory(bool required) {
static std::string path;
......
......@@ -38,6 +38,14 @@ int sm_arch(int device_id = -1);
*/
int sm_count(int device_id = -1);
/* \brief CUDA Multicast support status for device
*
* \param[in] device_id CUDA device (default is current device)
*
* \return CUDA multicast support flag
*/
bool supports_multicast(int device_id = -1);
/* \brief Path to CUDA Toolkit headers
*
* The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_
#include <pybind11/pybind11.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>
#include "cuda_runtime.h"
#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \
pybind11::enum_<transformer_engine::DType>(m, "DType") \
.value("kByte", transformer_engine::DType::kByte) \
.value("kInt32", transformer_engine::DType::kInt32) \
.value("kFloat32", transformer_engine::DType::kFloat32) \
.value("kFloat16", transformer_engine::DType::kFloat16) \
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type") \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \
.value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \
pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type") \
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout") \
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \
.value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \
.value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \
.value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \
.value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \
.value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend") \
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType") \
.value("RS", transformer_engine::CommOverlapType::RS) \
.value("AG", transformer_engine::CommOverlapType::AG); \
pybind11::enum_<transformer_engine::CommOverlapAlgo>(m, "CommOverlapAlgo") \
.value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \
.value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \
.value("SPLIT_PIPELINED_AG_P2P", \
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \
.value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \
.value("SPLIT_PIPELINED_RS_P2P", \
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \
m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \
        pybind11::call_guard<pybind11::gil_scoped_release>(), pybind11::arg("device_id") = -1);    \
m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \
        pybind11::call_guard<pybind11::gil_scoped_release>());
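// Illustrative sketch of how a framework extension might consume this helper (the module name
// below is an assumption, not something defined by this header):
//
//   #include "common/util/pybind_helper.h"
//   PYBIND11_MODULE(transformer_engine_framework_ext, m) {
//     NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m)
//     // ... framework-specific bindings follow ...
//   }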
#endif