Unverified Commit 933294dc authored by Alp Dener, committed by GitHub

[C/PyTorch] Userbuffers and comm+GEMM overlap algorithms refactored and moved to TE/common (#1067)



* moved userbuffers code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved comm+GEMM overlap code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed PyTorch dependency from comm+GEMM overlap in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated TE/PyTorch Python API to match the refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>
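For reference, the renamed Python bindings map onto the old ones roughly as follows (a minimal sketch based on the unit-test changes further down in this diff; only the names shown here are taken from the diff, everything else is illustrative):

```python
import transformer_engine.pytorch.cpp_extensions as tex

# Communication type is now a proper enum instead of a bare integer flag
comm_type = tex.CommOverlapType.AG  # previously passed as 1 ("ag")
comm_type = tex.CommOverlapType.RS  # previously passed as 0 ("rs")

# Overlap algorithms moved from tex.UbufOverlapAlgo.* to tex.CommOverlapAlgo.*
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P

# tex.UbufCommOverlap / tex.UbufP2PCommOverlap became tex.CommOverlap / tex.CommOverlapP2P,
# and tex.UbufBootstrapCallbacks became tex.CommOverlapHelper (see the test updates below).
```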

* updated unit tests to work with refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* added a pylint exception to comm+GEMM overlap test runner
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixing linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* added documentation for te.initialize_ub
Signed-off-by: Alp Dener <adener@nvidia.com>
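A hedged usage sketch for the newly documented calls (the shape/tp_size ordering follows the docs added here; the remaining keyword names are best-effort and should be checked against the rendered API reference):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes torch.distributed is already initialized and the CUDA device is set
seq_len, batch_size, hidden_size = 2048, 2, 4096
tp_size = dist.get_world_size()

te.initialize_ub(
    [seq_len * batch_size, hidden_size],  # sample communication-buffer shape
    tp_size,                              # tensor-parallel group size
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",             # or "mpi"/"gloo"; default priority described below
)

# ... build and run TE modules with their ub_overlap_* options enabled ...

te.destroy_ub()
```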

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed compile errors when building with NVTE_UB_WITH_MPI=1
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed default bootstrap backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* switched default bootstrap backend priority to MPI > Gloo > NCCL
Signed-off-by: Alp Dener <adener@nvidia.com>
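An illustrative sketch of the new default selection when no bootstrap backend is requested explicitly; this mirrors the MPI > Gloo > NCCL priority described above but is not the literal TE implementation:

```python
import torch.distributed as dist

def default_bootstrap_backend() -> str:
    # Prefer MPI, then Gloo, and only fall back to NCCL if nothing else is built in
    if dist.is_mpi_available():
        return "mpi"
    if dist.is_gloo_available():
        return "gloo"
    if dist.is_nccl_available():
        return "nccl"
    raise RuntimeError("No torch.distributed backend available for Userbuffers bootstrapping")
```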

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated bootstrap backend documentation
Signed-off-by: Alp Dener <adener@nvidia.com>

* close UB bootstrap socket to avoid interfering with CUDA Multicast shareable file handle send/recv
Signed-off-by: Alp Dener <adener@nvidia.com>

* added torch::Tensor wrappers for communication buffer and atomic counters so PyTorch can factor externally allocated memory into its garbage collection threshold
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* automated handling of world, local and node ranks/sizes within C++ CommOverlapHelper to simplify Python function signatures
Signed-off-by: Alp Dener <adener@nvidia.com>
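The simplified construction path now looks roughly like this (argument order taken from the updated unit test in this diff; the helper derives world/local/node ranks and sizes from the process group internally, so they no longer appear in the Python signature):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch.cpp_extensions as tex

# Launched via torchrun with one GPU per process
dist.init_process_group(backend="nccl")
bootstrap_pg = dist.new_group(backend="gloo")  # used only when UB was not built with MPI
tp_size = dist.get_world_size()

helper = (
    tex.CommOverlapHelper()
    if tex.ubuf_built_with_mpi()
    else tex.CommOverlapHelper(bootstrap_pg)
)
ub_obj = tex.CommOverlapP2P(
    (8192, 4096),            # communication buffer shape (rows, hidden)
    torch.bfloat16,          # buffer dtype
    helper,
    tp_size,                 # tensor-parallel group size
    tex.CommOverlapType.AG,
    set_sm_margin=False,
    atomic_gemm=False,
    aggregate=False,
)
```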

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed incorrect read of environment variables
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected priority for _SOCKET_IFNAME environment variables in UB bootstrapping
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved multicast support check to cuda_runtime.h and replaced cudaDeviceGetProp call with cached sm_count()
Signed-off-by: Alp Dener <adener@nvidia.com>
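The multicast capability check is also visible from Python; an illustrative guard (the unit tests check the same predicate, and the module alias follows the test scripts in this PR, but the skip behaviour shown here is a sketch rather than the repository's exact handling):

```python
import pytest
import transformer_engine.pytorch.cpp_extensions as tex

if not tex.device_supports_multicast():
    pytest.skip("CUDA Multicast is not supported on this device/driver",
                allow_module_level=True)
```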

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* removed commented out old code and replaced external collective function type defines with aliases
Signed-off-by: Alp Dener <adener@nvidia.com>

* compile-time CUDA version guard for CUDA Driver Multicast attribute
Signed-off-by: Alp Dener <adener@nvidia.com>

* added compile-time CUDA version guards to Multicast code in Userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* condensed UB docs, corrected const violations
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed autodoc rst for UB calls, added CUDA version guard on Multicast UB kernels
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect UB type reporting for P2P overlaps, comment reformatting
Signed-off-by: Alp Dener <adener@nvidia.com>

* add docstring to tex.ubuf_built_with_mpi()
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 35bbe740
@@ -38,3 +38,4 @@ dist/
 downloads/
 .pytest_cache/
 compile_commands.json
+.nfs
@@ -11,7 +11,6 @@ import setuptools
 from .utils import (
     all_files_in_dir,
     cuda_archs,
-    cuda_path,
     cuda_version,
 )
@@ -29,9 +28,6 @@ def setup_pytorch_extension(
     sources = [
         csrc_source_files / "common.cu",
         csrc_source_files / "ts_fp8_op.cpp",
-        csrc_source_files / "userbuffers" / "ipcsocket.cc",
-        csrc_source_files / "userbuffers" / "userbuffers.cu",
-        csrc_source_files / "userbuffers" / "userbuffers-host.cpp",
     ] + all_files_in_dir(extensions_dir)

     # Header files
@@ -85,19 +81,14 @@ def setup_pytorch_extension(
             continue  # Already handled
         nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

-    # Libraries
-    library_dirs = []
-    libraries = []
-    if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
+    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
         assert (
             os.getenv("MPI_HOME") is not None
-        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
-        mpi_home = Path(os.getenv("MPI_HOME"))
-        include_dirs.append(mpi_home / "include")
+        ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
+        mpi_path = Path(os.getenv("MPI_HOME"))
+        include_dirs.append(mpi_path / "include")
         cxx_flags.append("-DNVTE_UB_WITH_MPI")
         nvcc_flags.append("-DNVTE_UB_WITH_MPI")
-        library_dirs.append(mpi_home / "lib")
-        libraries.append("mpi")

     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]
@@ -112,6 +103,4 @@
             "cxx": cxx_flags,
             "nvcc": nvcc_flags,
         },
-        libraries=[str(lib) for lib in libraries],
-        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
     )
@@ -51,3 +51,7 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.moe_permute

 .. autoapifunction:: transformer_engine.pytorch.moe_unpermute
+
+.. autoapifunction:: transformer_engine.pytorch.initialize_ub
+
+.. autoapifunction:: transformer_engine.pytorch.destroy_ub
@@ -57,13 +57,20 @@ class TimedBdist(bdist_wheel):
 def setup_common_extension() -> CMakeExtension:
     """Setup CMake extension for common library"""
+    cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())]
+    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
+        assert (
+            os.getenv("MPI_HOME") is not None
+        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
+        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")
     # Project directory root
     root_path = Path(__file__).resolve().parent
     return CMakeExtension(
         name="transformer_engine",
         cmake_path=root_path / Path("transformer_engine/common"),
-        cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())],
+        cmake_flags=cmake_flags,
     )
...

@@ -22,6 +22,7 @@ import transformer_engine.pytorch.cpp_extensions as tex
 from transformer_engine.common.recipe import Format
 from transformer_engine.pytorch.fp8 import _default_sf_compute

+warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=UserWarning)

@@ -32,8 +33,8 @@ torch_dtypes = {
 }

 nvte_comm_types = {
-    "rs": 0,
-    "ag": 1,
+    "rs": tex.CommOverlapType.RS,
+    "ag": tex.CommOverlapType.AG,
 }

@@ -75,7 +76,7 @@ def _parse_args(argv=None, namespace=None):
     parser.add_argument(
         "--comm-type",
         type=partial(_mapped_argtype, typemap=nvte_comm_types),
-        default=0,
+        default=tex.CommOverlapType.AG,
         help="Comm type to overlap.",
     )
     parser.add_argument(

@@ -156,11 +157,9 @@
         if opts.fp8:
             warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
             opts.fp8 = False
-    elif opts.comm_type == 1:
+    elif opts.comm_type == tex.CommOverlapType.AG:
         if opts.atomic:
             setattr(opts, "atomic_rs_p2p", opts.p2p)
-        if not opts.p2p:
-            warnings.warn("All-gather overlap is only supported with point-2-point comms.")
         opts.p2p = True

     if opts.atomic:
@@ -283,35 +282,35 @@ def _main(opts):
     if WORLD_RANK == 0:
         print("\n", end="", flush=True)

-    ub_callbacks = (
-        tex.UbufBootstrapCallbacks()
+    helper = (
+        tex.CommOverlapHelper()
         if tex.ubuf_built_with_mpi()
-        else tex.UbufBootstrapCallbacks(bootstrap_pg, bootstrap_pg)
+        else tex.CommOverlapHelper(bootstrap_pg)
     )

-    if opts.comm_type == 0:
+    if opts.comm_type == tex.CommOverlapType.RS:
         if opts.bulk_overlap:
-            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_RS
+            ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS
         elif opts.p2p:
             ub_algo = (
-                tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                 if opts.atomic
-                else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
             )
         else:
             ub_algo = (
-                tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                 if opts.atomic
-                else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
             )
-    elif opts.comm_type == 1:
+    elif opts.comm_type == tex.CommOverlapType.AG:
         if opts.bulk_overlap:
-            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+            ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG
         else:
             ub_algo = (
-                tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+                tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
                 if opts.atomic
-                else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
             )
     else:
         raise TypeError("Invalid comm+GEMM overlap type!")

@@ -322,95 +321,55 @@ def _main(opts):
     hidden_size = opts.num_heads * opts.head_dim
     inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
     outer_size = reduce(operator.mul, inp_shape[:-1], 1)
-    ubuf_dtype = torch.bfloat16
-    if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output):
-        ubuf_dtype = torch.uint8
-    sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda")
+    buffer_dtype = torch.bfloat16
+    if (
+        opts.fp8
+        and not opts.bulk_overlap
+        and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output)
+    ):
+        buffer_dtype = torch.uint8
     ub_obj = (
-        tex.UbufP2PCommOverlap(
-            sample_buffer,  # Sample userbuffer
-            WORLD_RANK,  # World rank
-            WORLD_SIZE,  # World size
-            LOCAL_RANK,  # Rank within the node
-            LOCAL_SIZE,  # Number of ranks/GPUs per node
-            0,  # Node ID
-            1,  # Number of nodes
+        tex.CommOverlapP2P(
+            (outer_size, hidden_size),
+            buffer_dtype,
+            helper,
             tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
-            1,  # Number of communication SMs
-            1,  # CGA cluster size
-            opts.comm_type == 0 or opts.atomic,  # Set SM margin
-            opts.aggregate,  # Aggregate 2X GEMM chunks
-            3,  # Max concurrent GEMM streams
-            opts.comm_type == 0,  # overlap with reduce scatter
-            opts.atomic,  # use a single GEMM with atomic-counters
-            not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
-            ub_callbacks,
+            opts.comm_type,
+            set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
+            atomic_gemm=opts.atomic,
+            aggregate=opts.aggregate,
+            use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
         )
         if opts.p2p
-        else tex.UbufCommOverlap(
-            sample_buffer,  # Sample userbuffer
-            WORLD_RANK,  # World rank
-            WORLD_SIZE,  # World size
-            LOCAL_RANK,  # Rank within the node
-            LOCAL_SIZE,  # Number of ranks/GPUs per node
-            0,  # Node ID
-            1,  # Number of nodes
+        else tex.CommOverlap(
+            (outer_size, hidden_size),
+            buffer_dtype,
+            helper,
            tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
-            16,  # Number of communication SMs
-            2,  # CGA cluster size
-            4,  # Number of communication splits
-            True,  # Set SM margin
-            3,  # Max concurrent GEMM streams
-            opts.atomic,  # Use a single GEMM with atomic-counters
-            ub_callbacks,
+            atomic_gemm=opts.atomic,
         )
     )

     # Numerical check on AG + atomic GEMM requires testing an AG+RS pair
     ub_obj2 = None
-    if opts.atomic and opts.comm_type == 1 and opts.check_numerics:
-        sample_buffer2 = torch.empty(
-            (outer_size, hidden_size),
-            dtype=torch.uint8 if opts.fp8_output else torch.bfloat16,
-            device="cuda",
-        )
+    if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics:
         ub_obj2 = (
-            tex.UbufP2PCommOverlap(
-                sample_buffer2,  # Sample userbuffer
-                WORLD_RANK,  # World rank
-                WORLD_SIZE,  # World size
-                LOCAL_RANK,  # Rank within the node
-                LOCAL_SIZE,  # Number of ranks/GPUs per node
-                0,  # Node ID
-                1,  # Number of nodes
+            tex.CommOverlapP2P(
+                (outer_size, hidden_size),
+                torch.uint8 if opts.fp8_output else torch.bfloat16,
+                helper,
                 tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
-                1,  # Number of communication SMs
-                1,  # CGA cluster size
-                True,  # Set SM margin
-                False,  # Aggregate 2X GEMM chunks
-                3,  # Max concurrent GEMM streams
-                True,  # overlap with reduce scatter
-                True,  # use a single GEMM with atomic-counters
-                True,  # use copy engine for P2P communications
-                ub_callbacks,
+                tex.CommOverlapType.RS,
+                set_sm_margin=True,
+                atomic_gemm=True,
             )
             if opts.atomic_rs_p2p
-            else tex.UbufCommOverlap(
-                sample_buffer2,  # Sample userbuffer
-                WORLD_RANK,  # World rank
-                WORLD_SIZE,  # World size
-                LOCAL_RANK,  # Rank within the node
-                LOCAL_SIZE,  # Number of ranks/GPUs per node
-                0,  # Node ID
-                1,  # Number of nodes
+            else tex.CommOverlap(
+                (outer_size, hidden_size),
+                torch.uint8 if opts.fp8_output else torch.bfloat16,
+                helper,
                 tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
-                16,  # Number of communication SMs
-                2,  # CGA cluster size
-                4,  # Number of communication splits
-                True,  # Set SM margin
-                3,  # Max concurrent GEMM streams
-                True,  # Use a single GEMM with atomic-counters
-                ub_callbacks,
+                atomic_gemm=True,
            )
        )
@@ -426,12 +385,12 @@ def _main(opts):
         local_kernel_t_shape = (ffn_hidden_size, hidden_size)
         local_inp_shape = (outer_size, hidden_size)
         # Bulk overlap comm tensor is distributed for AG overlap only
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             bulk_inp_shape = (outer_size // tp_size, hidden_size)
         else:
             bulk_inp_shape = (outer_size, hidden_size)
     else:
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             # (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P)
             local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size)
             local_inp_shape = (outer_size // tp_size, hidden_size)

@@ -472,7 +431,7 @@ def _main(opts):
             std=opts.std,
         )
     else:
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K)
             ker_g = torch.transpose(
                 te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1

@@ -494,7 +453,7 @@ def _main(opts):
     ).to(dtype=torch.float32)

     if opts.bulk_overlap:
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0]
         else:
             # First all-gather all the bulk inputs into a list

@@ -505,7 +464,7 @@ def _main(opts):
     else:
         ref_g = torch.matmul(inp_g, ker_g)
         if ub_obj2 is not None:
-            inp2_g = torch.nn.functional.gelu(ref_g)
+            inp2_g = torch.nn.functional.gelu(ref_g)  # pylint: disable=not-callable
             ref2_g = torch.matmul(inp2_g, ker2_g)

     if opts.fp8:

@@ -529,7 +488,7 @@ def _main(opts):
         fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax)
         ref_amax = torch.max(torch.abs(ref_g))
         fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax)
-        if opts.bulk_overlap and opts.comm_type == 0:
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             bulk_amax = torch.max(torch.abs(bulk_inp))
             fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax)
         elif ub_obj2 is not None:

@@ -551,7 +510,7 @@ def _main(opts):
         kernel_t_fp8 = tex.cast_to_fp8(
             kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype
         )
-        if opts.bulk_overlap and opts.comm_type == 0:
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             bulk_inp_fp8 = tex.cast_to_fp8(
                 bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype
             )

@@ -574,7 +533,7 @@ def _main(opts):
             rtol=0.125,
             atol=0.0675,
         )
-        if opts.bulk_overlap and opts.comm_type == 0:
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             torch.allclose(
                 bulk_inp.to(dtype=torch.float32),
                 bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT],

@@ -590,7 +549,7 @@ def _main(opts):
         )

         # Set Fp8 scales for userbuffers
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT])
         if ub_obj2 is not None:
             ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT])

@@ -602,7 +561,7 @@ def _main(opts):
     # Set up comm/compute buffers
     ubuf_out2 = None
     rs_out2 = None
-    if opts.comm_type == 1:
+    if opts.comm_type == tex.CommOverlapType.AG:
         if opts.bulk_overlap:
             ub_obj.copy_input_to_ubuf(bulk_inp, 1)
             gemm_inp = inp

@@ -686,9 +645,9 @@ def _main(opts):
             gelu=False,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
             ub_algo=(
-                tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                 if opts.atomic_rs_p2p
-                else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                else tex.CommOverlapAlgo.ATOMIC_GEMM_RS
             ),
             ub=ub_obj2,
             extra_output_tensor=rs_out2,

@@ -762,10 +721,14 @@ def _main(opts):
     avg_gpu_time = sum(gpu_times) / opts.timing_iters
     gemm_name = "".join(
         [
-            "p2p all-gather + " if opts.comm_type == 1 else "",
+            "p2p all-gather + " if opts.comm_type == tex.CommOverlapType.AG else "",
             "atomic " if opts.atomic else "",
             "GEMM",
-            (f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == 0 else ""),
+            (
+                f" + {'p2p ' if opts.p2p else ''}reduce-scatter"
+                if opts.comm_type == tex.CommOverlapType.RS
+                else ""
+            ),
         ]
     )
     timing_info = (

@@ -781,7 +744,7 @@ def _main(opts):
     dist.barrier(tp_group)
     if opts.bulk_overlap:
         output_info = ""
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             # Bulk overlap AG output is already gathered
             test_out = ub_obj.get_ubuf_output(1)
         else:

@@ -794,7 +757,7 @@ def _main(opts):
         output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}"
         dist_print(
             output_info,
-            src=0 if opts.comm_type == 0 else None,
+            src=0 if opts.comm_type == tex.CommOverlapType.RS else None,
             section=True,
         )

@@ -805,7 +768,7 @@ def _main(opts):
         )
         dist_print(nonzero_info, src=0, section=True, group=tp_group)
     else:
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             if ub_obj2 is not None:
                 # AG+RS Output: (M/P, N) -> gather -> (M, N)
                 output = rs_out2.to(dtype=torch.float32)

...
@@ -9,7 +9,6 @@ import sys
 import socket
 import argparse
 import warnings
-from functools import partial

 import torch
 import torch.distributed as dist

...
@@ -42,6 +42,9 @@ if not tex.device_supports_multicast():
 # Force GPU kernels to launch in the order they're executed by the host CPU
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

+# Clear torch.dynamo caches
+torch._dynamo.reset()
+

 def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate):
     test_path = TEST_ROOT / "run_gemm_with_overlap.py"

...
@@ -80,7 +80,11 @@ list(APPEND transformer_engine_SOURCES
      fused_softmax/scaled_upper_triang_masked_softmax.cu
      fused_softmax/scaled_aligned_causal_masked_softmax.cu
      fused_rope/fused_rope.cu
-     recipe/delayed_scaling.cu)
+     recipe/delayed_scaling.cu
+     comm_gemm_overlap/userbuffers/ipcsocket.cc
+     comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+     comm_gemm_overlap/userbuffers/userbuffers.cu
+     comm_gemm_overlap/comm_gemm_overlap.cpp)
 add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 target_include_directories(transformer_engine PUBLIC
                            "${CMAKE_CURRENT_SOURCE_DIR}/include")

@@ -93,6 +97,15 @@ target_include_directories(transformer_engine PRIVATE
                            ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
 target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")

+# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
+option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
+if (NVTE_UB_WITH_MPI)
+    find_package(MPI REQUIRED)
+    target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX)
+    target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES})
+    target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
+endif()
+
 # Hack to enable dynamic loading in cuDNN frontend
 target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)

...
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <cassert>
#include <numeric>
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "userbuffers/userbuffers.h"
#define HALF_BYTES 2
#define UB_MAX_SM 32
using namespace std::placeholders;
namespace transformer_engine {
/***************************************************************************************************
* Comm+GEMM Overlap Common Core
**************************************************************************************************/
bool ubuf_built_with_mpi() {
#ifdef NVTE_UB_WITH_MPI
return true;
#else
return false;
#endif
}
CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm) {
// Initialize userbuf communicator
if (!_comm_created) {
if (myrank == 0) {
printf("!!! [UB] Create Userbuffers Communicator\n");
}
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
#else
create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
allgather_handle, barrier_handle, 1, 1, tp_size, 1);
#endif
_comm_created = true;
}
_use_ce = static_cast<int>(use_ce);
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1));
_stream_compute.push_back(std::move(stream));
}
_num_splits = num_splits;
_rank = _ub_comm->myrank;
_tp_size = tp_size;
_tp_id = _rank % _tp_size;
// Set the number of SMs for GEMM with margin
int sm_count = transformer_engine::cuda::sm_count();
_math_sms = (set_sm_margin) ? sm_count - num_comm_sm : sm_count;
_math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);
_atomic_gemm = atomic_gemm;
if (_atomic_gemm) {
void *counter_ptr;
size_t counter_bytes = _num_splits * 2 * sizeof(int32_t);
NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes));
NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes));
NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2));
_counter = TensorWrapper(counter_ptr, std::vector<size_t>{static_cast<size_t>(_num_splits * 2)},
DType::kInt32);
}
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);
}
CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_stop_comm);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
if (_atomic_gemm) cudaFree(_counter.dptr());
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
if (_comm_created) {
#ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm);
#else
destroy_communicator(_ub_comm);
#endif
_comm_created = false;
}
}
/***************************************************************************************************
* Comm+GEMM Overlap Base (Pipelined / Collective)
**************************************************************************************************/
CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool atomic_gemm)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
num_comm_sm, set_sm_margin, false, atomic_gemm) {
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
"Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ",
"or 2 (multi-atomic).");
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
}
CommOverlapBase::~CommOverlapBase() {
cudaEventDestroy(_start_d2dcopy);
cudaStreamDestroy(_stream_comm);
}
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
} else {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
comm_elements *= 2;
assert(rs_output.numel() == _ubuf.numel() / _tp_size);
assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
comm_elements, _ub_comm, _stream_comm);
} else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
}
}
assert(pre_gelu_out.numel() == 0);
nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
stream_main);
_ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::bulk_overlap
/*
** Atomic FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions
size_t m = A.size(0);
size_t k = A.size(1);
size_t n = B.size(0);
size_t m_chunk = m / _num_splits;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
// Reset atomic counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _num_splits, false, stream_main);
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0));
assert(pre_gelu_out.numel() == 0);
auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr);
auto workspace_chunk =
TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(),
_stream_compute[0]);
for (int i = 0; i < _num_splits; i++) {
if (_rs_kernel_type == 1) {
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
&counter_ptr[i], _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_num_splits, &counter_ptr[i], _ub_comm,
_stream_comm);
}
} else if (_rs_kernel_type == 2) {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m,
_num_splits, counter_ptr, _ub_comm,
_stream_comm);
}
break;
} else {
consumer(counter_ptr, i, _stream_comm);
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(rs_output_ptr, _ubuf_scale_inv,
_ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, _stream_comm);
}
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
_ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}  // CommOverlapBase::atomic_gemm_overlap_rs
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main) {
// Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
size_t m = A.size(0);
size_t k = A.size(1);
size_t n = B.size(0);
size_t m_chunk = m / _num_splits;
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0));
assert(pre_gelu_out.numel() == 0);
if (gemm_overlap) {
auto input_a_chunk =
TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk =
TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
auto workspace_chunk = TensorWrapper(
workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * D.element_size();
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
input_a_chunk = TensorWrapper(reinterpret_cast<void *>(input_a_chunk_ptr), {m_chunk, k},
A.dtype(), nullptr, nullptr, A.scale_inv());
output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr), {n, m_chunk},
D.dtype(), D.amax(), D.scale(), nullptr);
workspace_chunk = TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Communication chunk
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm);
}
} else {
for (int i = 0; i < _num_splits; i++) {
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto input_a_chunk = TensorWrapper(reinterpret_cast<void *>(input_a_chunk_ptr), {m_chunk, k},
A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr),
{n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::split_overlap_rs
/***************************************************************************************************
* Comm+GEMM Overlap P2P Base (Ring-Exchange)
**************************************************************************************************/
CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm, bool aggregate)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
num_comm_sm, set_sm_margin, use_ce, atomic_gemm) {
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / tp_size;
_num_ubuf_chunks = tp_size;
if (_is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
// outputs for reduction at the end of the pipelining.
buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1);
_num_ubuf_chunks = tp_size * 2 - 1;
}
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
{buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
_rank_round_tp = (_rank / _tp_size) * _tp_size;
_next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp;
_prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp;
_self_chunk_id = _tp_id;
if (_atomic_gemm && !_is_reduce_scatter) {
_use_multiatomic_ag = getenv<bool>("NVTE_AG_P2P_MULTI_ATOMIC");
if (_use_multiatomic_ag) {
_use_ce = 0;
_ub_comm->push = 1;
if (_rank == 0) {
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
}
}
_self_chunk_id = 0;
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
}
CommOverlapP2PBase::~CommOverlapP2PBase() {
cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send);
cudaStreamDestroy(_stream_recv);
cudaStreamDestroy(_stream_send);
}
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes input_b is pre-copied to _ubufs[rank_id]. This is needed so that the AG
** outputs on each rank end up in contiguous memory after all ring exchange phases.
*/
void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const size_t m = (transa) ? A.size(0) : A.size(1);
const size_t n = _ubuf.size(0);
const size_t n_chunk = n / _tp_size;
assert(pre_gelu_out.numel() == 0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Create a GEMM output buffer with N+1 chunks in contiguous memory
void *D_buffer_ptr;
int D_chunk_bytes = n_chunk * m * D.element_size();
NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main));
auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr);
// Reset atomic counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _tp_size, true, stream_main);
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv());
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk =
TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. The buffer being sent is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubufs[rank] so that the
// AG output ends up contiguous across all ranks after the ring exchanges.
int send_chunk_id = i;
int recv_chunk_id = i + 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
if (_use_multiatomic_ag) {
if (i == 0) {
_ub_comm->use_ce = 0;
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
true, _stream_recv);
}
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank,
_stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank,
_stream_recv);
producer(counter_ptr, recv_chunk_id, _stream_recv);
}
if (i == 0) {
nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false,
_counter.data(), stream_main);
}
}
// Store the input activation for backprop
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes,
cudaMemcpyDeviceToDevice, stream_main));
// Return the last N rows of D_buffer
NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
cudaMemcpyDeviceToDevice, stream_main));
// Clean up buffer allocation
NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main));
_ub_comm->sms = ori_sms;
} // CommOverlapP2PBase::atomic_gemm_overlap_ag
/*
** Split AllGather + GEMM using P2P communication
** This function assumes input_b is pre-copied to _ubufs[rank_id]. This is needed so that the AG
** outputs on each rank end up in contiguous memory after all ring exchange phases.
*/
void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const size_t m = (transa) ? A.size(0) : A.size(1);
const size_t k = (transa) ? A.size(1) : A.size(0);
const size_t n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
const int output_chunk_bytes = (n_chunk * m) * D.element_size();
const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.dptr());
char *pre_gelu_out_ptr = reinterpret_cast<char *>(pre_gelu_out.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
if (_aggregate) {
const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.dptr());
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp;
// Ring exchange of 2X inputs chunks
for (int i = 0; i < num_steps; i++) {
send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size;
recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size;
send_offset = comm_bytes * send_chunk_id;
recv_offset = comm_bytes * recv_chunk_id;
// GEMM
char *input_b_chunk_ptr = input_b_ptr + send_offset;
auto input_b_chunk =
TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(),
nullptr, nullptr, B.scale_inv());
char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes);
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr),
{n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr);
char *aux_chunk_ptr =
(do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
auto aux_chunk_shape =
(do_gelu) ? std::vector<size_t>{n_chunk * 2, m} : std::vector<size_t>{0};
auto aux_chunk = TensorWrapper(reinterpret_cast<void *>(aux_chunk_ptr), aux_chunk_shape,
pre_gelu_out.dtype());
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
if (i < num_steps - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, _stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
}
}
} else {
for (int i = 0; i < _tp_size; i++) {
      // Set the userbuffer id. The buffer being sent is the input for the current GEMM chunk.
      // The initial input chunk is stored in _ubufs[rank] so that the AG output ends up
      // contiguous on every rank after the ring exchanges.
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(),
nullptr, nullptr, B.scale_inv());
char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes);
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr), {n_chunk, m},
D.dtype(), D.amax(), D.scale(), nullptr);
char *aux_chunk_ptr =
(do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
auto aux_chunk_shape = (do_gelu) ? std::vector<size_t>{n_chunk, m} : std::vector<size_t>{0};
auto aux_chunk = TensorWrapper(reinterpret_cast<void *>(aux_chunk_ptr), aux_chunk_shape,
pre_gelu_out.dtype());
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
if (i < _tp_size - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, _stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
}
}
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
} // CommOverlapP2PBase::split_overlap_ag
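For reference, the non-aggregated loop above follows a rotating chunk schedule: on step i, rank r runs its GEMM on chunk (r - i) mod tp_size while forwarding that same chunk to the next rank and receiving chunk (r - i - 1) mod tp_size from the previous rank, so the gathered output lands contiguously on every rank. A minimal host-side sketch of that index arithmetic (plain C++; tp_size = 4 is an arbitrary illustrative value, not something the library fixes):

// Illustrative only: prints the per-step GEMM/send and recv chunk ids that the
// non-aggregated split_overlap_ag loop walks through on each rank.
#include <cstdio>

int main() {
  const int tp_size = 4;  // hypothetical tensor-parallel group size
  for (int tp_id = 0; tp_id < tp_size; ++tp_id) {
    std::printf("rank %d:", tp_id);
    for (int i = 0; i < tp_size; ++i) {
      const int send_chunk_id = (tp_size + tp_id - i) % tp_size;      // GEMM input this step
      const int recv_chunk_id = (tp_size + tp_id - i - 1) % tp_size;  // arrives for the next step
      std::printf("  [step %d: gemm/send=%d recv=%d]", i, send_chunk_id, recv_chunk_id);
    }
    std::printf("\n");
  }
  return 0;
}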
/*
** Split ReduceScatter + Atomic GEMM using P2P communication
*/
void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Reset counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _tp_size, false, stream_main);
// Catch up the main stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr);
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk =
TensorWrapper(workspace.data(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(),
stream_main);
// P2P communication chunk
for (int i = 1; i < _tp_size; i++) {
int send_chunk_id = i - 1;
int recv_chunk_id = send_chunk_id + _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
consumer(counter_ptr, send_chunk_id, _stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
_stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
_stream_recv);
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size,
_ubufs[0].numel(), stream_main););
} else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
}
_ub_comm->sms = ori_sms;
}
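The atomic-GEMM variant above relies on the _counter buffer to learn when each output chunk is ready to be communicated: reset_counters re-arms the flags on the main stream, and consumer on _stream_recv blocks until the atomic GEMM marks a chunk as produced. A host-only sketch of the flag layout, assuming (consistent with reset_counters_kernel added further down) that a value of 1 means "not yet produced" and 0 means "ready":

// Host-side illustration (no CUDA) of the 2 * num_chunks flag layout that
// reset_counters establishes before the atomic GEMM starts.
#include <cstdio>
#include <vector>

int main() {
  const int num_chunks = 4;  // hypothetical _tp_size
  std::vector<unsigned int> counters(2 * num_chunks);
  for (int i = 0; i < num_chunks; ++i) {
    counters[i] = 1;               // chunk-produced flags start as "not ready"
    counters[i + num_chunks] = 0;  // secondary flags start cleared
  }
  counters[0] = 0;  // pretend the producer (the atomic GEMM) just finished chunk 0
  // A consumer would poll counters[i] and only communicate chunk i once it reads 0.
  std::printf("chunk 0 ready: %s\n", counters[0] == 0 ? "yes" : "no");
  return 0;
}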
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
size_t k = A.size(1);
size_t n = B.size(0);
// Get communication and GEMM input chunk sizes
size_t n_chunk = n / _tp_size;
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int input_b_chunk_bytes = n_chunk * k * B.element_size();
// Get input and workspace data pointers
char *input_b_ptr = reinterpret_cast<char *>(B.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Catch up the main stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
// GEMM and send/recv chunks
for (int i = 0; i < _tp_size; i++) {
// GEMM chunk
int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes);
auto input_b_chunk = TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk, k},
B.dtype(), nullptr, nullptr, B.scale_inv());
auto output_chunk =
TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr);
char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]);
if (i > 0) {
// P2P communication chunk
int send_offset = comm_bytes * (i - 1);
int recv_offset = comm_bytes * (i - 1 + _tp_size);
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
_stream_send);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
_stream_recv);
}
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size,
_ubufs[0].numel(), stream_main););
} else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
}
_ub_comm->sms = ori_sms;
}
} // namespace transformer_engine
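Both ReduceScatter overlaps above finish by summing the tp_size staged output chunks; when the communication buffer holds FP8 GEMM results and the reduce-scatter output is 16-bit, each element is first dequantized with _ubuf_scale_inv. A small host-side sketch of that math, using plain floats as stand-ins for the FP8 inputs and BF16 output:

// Illustrative reference for the final reduction: out[j] = sum_i in[i][j] * scale_inv.
#include <cstdio>
#include <vector>

int main() {
  const int tp_size = 4, chunk_elems = 8;                  // hypothetical sizes
  const float scale_inv = 0.0625f;                         // stand-in for _ubuf_scale_inv
  std::vector<float> chunks(tp_size * chunk_elems, 2.0f);  // stand-in for the FP8 chunk values
  std::vector<float> rs_out(chunk_elems, 0.0f);
  for (int j = 0; j < chunk_elems; ++j) {
    for (int i = 0; i < tp_size; ++i) {
      rs_out[j] += chunks[i * chunk_elems + j] * scale_inv;  // dequantize, then accumulate
    }
  }
  std::printf("rs_out[0] = %f\n", rs_out[0]);  // 4 * 2.0 * 0.0625 = 0.5
  return 0;
}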
@@ -20,7 +20,9 @@
#include <utility>

#include "common/util/cuda_driver.h"
+#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
+#include "common/util/system.h"
#include "ipcsocket.h"
#include "userbuffers.h"

@@ -44,31 +46,19 @@ static MPI_Comm EXT_COMM_INTER;
  } while (false)

void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
-                      ExtComm group) {
-  // UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE,
-  //                            globaldata, globalbytes, MPI_BYTE,
-  //                            static_cast<MPI_Comm>(group)));
-  MPI_Comm comm = static_cast<MPI_Comm>(group);
+                      ExtComm comm) {
  int numranks;
  UB_MPI_CHECK(MPI_Comm_size(comm, &numranks));
  assert(globalbytes == numranks * localbytes);
-  int myrank;
-  UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank));
-  char *globaltarget = reinterpret_cast<char *>(globaldata) + (myrank * localbytes);
-  memcpy(globaltarget, localdata, localbytes);
-  for (int n = 0; n < numranks; n++) {
-    globaltarget = reinterpret_cast<char *>(globaldata) + (n * localbytes);
-    UB_MPI_CHECK(MPI_Bcast(globaltarget, localbytes, MPI_BYTE, n, comm));
-  }
+  UB_MPI_CHECK(
+      MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm));
}

-void ub_mpi_barrier(ExtComm group) { UB_MPI_CHECK(MPI_Barrier(static_cast<MPI_Comm>(group))); }
+void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); }

#else
-static char EXT_COMM_WORLD[] = "world";
-static char EXT_COMM_INTRA[] = "intra";
-static char EXT_COMM_INTER[] = "inter";
+#define EXT_COMM_WORLD "world"
+#define EXT_COMM_INTRA "intra"
+#define EXT_COMM_INTER "inter"
#endif

#define MULTICAST_GB_TOTAL 512

@@ -106,11 +96,10 @@ int pipe_rank(communicator *comm, int step) {
  return newnode * numlocal + newlocal;
}

-int create_communicator_grouped2(
-    communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
-    int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
-    std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
-    int tensornodes) {
+int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
+                                 int numlocal, int mynode, int numnodes,
+                                 ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
+                                 int pipegpus, int pipenodes, int tensorgpus, int tensornodes) {
  *comm = new communicator();

  (*comm)->comm_world = EXT_COMM_WORLD;

@@ -214,8 +203,11 @@
  (*comm)->asyncblocks = 16;

#define NBUF 2
-  if ((*comm)->sm_arch >= 9 && (*comm)->ar2_nvsize > 1 &&
-      !getenv("UB_SKIPMC")) {  // multicast init only for TP ops (____2 operations)
+#if CUDART_VERSION >= 12010
+  if (!transformer_engine::getenv<bool>("UB_SKIPMC") &&
+      transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) {
+    // multicast init only for TP ops (____2 operations)
    size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30);
    (*comm)->mc_offset = 0;
    (*comm)->use_mc = 1;

@@ -291,20 +283,20 @@
    (*comm)->_barrier((*comm)->comm_world);
    if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize);
  } else {
+#endif
    if (!(*comm)->myrank) printf("MC NOT initialized and used\n");
    (*comm)->mc_maxsize = 0;
    (*comm)->mc_offset = 0;
    (*comm)->use_mc = 0;
+#if CUDART_VERSION >= 12010
  }
+#endif

#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
  // peer pointers + op flags + comm buffer
-  NVTE_CHECK_CUDA(
-      cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE));  // flags and pointers, no block data yet
-  NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
-  register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false);
+  register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true);
  NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));

@@ -346,18 +338,17 @@
  return 0;
}

-int create_communicator_grouped(
-    communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
-    int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
-    std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes) {
+int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal,
+                                int numlocal, int mynode, int numnodes,
+                                ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
+                                int pipegpus, int pipenodes) {
  return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
                                      ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1);
}

int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
-                        int mynode, int numnodes,
-                        std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
-                        std::function<void(ExtComm)> ext_barrier) {
+                        int mynode, int numnodes, ExtAllgatherOp ext_allgather,
+                        ExtBarrierOp ext_barrier) {
  return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
                                      ext_allgather, ext_barrier, 1, 1, 1, 1);
}

@@ -428,7 +419,7 @@ int create_communicator_mpi(communicator **comm) {
void destroy_communicator(communicator *comm) {
  for (int hndl = 0; hndl < comm->free_region; hndl++) {
-    if (hndl > 0 && comm->use_mc && comm->mem_dealloc[hndl]) {
+    if (comm->use_mc && comm->mem_dealloc[hndl]) {
      for (int rank = 0; rank < comm->nvsize; rank++) {
        if (rank == comm->nvrank) {
          NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);

@@ -479,6 +470,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
  comm->memflags[hndl] = 0;
  comm->mem_dealloc[hndl] = alloc;

+#if CUDART_VERSION >= 12010
  if (comm->use_mc && alloc) {
    int nranks = comm->nvsize;  // total GPUs in NVLINK domain
    int myrank = comm->nvrank;

@@ -594,6 +586,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
    }
  } else {
+#endif
    if (alloc) {
      NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
      NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));

@@ -624,7 +617,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
    NVTE_CHECK_CUDA(cudaDeviceSynchronize());
    free(tmp);
+#if CUDART_VERSION >= 12010
  }
+#endif
  comm->mem_size[hndl] = aligned_size;

  comm->mem_ptr[hndl] = *gpubuff;
......
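When Userbuffers is built without NVTE_UB_WITH_MPI, the bootstrap collectives in the hunks above are supplied by the caller through the ExtAllgatherOp/ExtBarrierOp callbacks instead of the MPI wrappers. A single-process stand-in that satisfies the same contract (globalbytes equals numranks * localbytes, with each rank's data copied into its slot) is sketched below; real frameworks back these callbacks with their own collectives (for example a torch.distributed process group), so this is illustrative only:

// Sketch only: a trivial ExtAllgatherOp/ExtBarrierOp pair for a single-process setup.
#include <cassert>
#include <cstddef>
#include <cstring>
#include <functional>

using ExtComm = const char *;  // mirrors the non-MPI definition in userbuffers.h
using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
using ExtBarrierOp = std::function<void(ExtComm)>;

int main() {
  const int numranks = 1;  // single process, so the allgather is just a copy into slot 0
  ExtAllgatherOp allgather = [&](void *global, size_t globalbytes, void *local, size_t localbytes,
                                 ExtComm) {
    assert(globalbytes == numranks * localbytes);
    std::memcpy(global, local, localbytes);
  };
  ExtBarrierOp barrier = [](ExtComm) { /* nothing to synchronize with one rank */ };

  int payload = 42, gathered = 0;
  allgather(&gathered, sizeof(gathered), &payload, sizeof(payload), "world");
  barrier("world");
  return gathered == 42 ? 0 : 1;
}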
@@ -392,14 +392,14 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
  if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
}  // fp16 reduce-scatter kernel (out of place)

-#if __CUDA_ARCH__ >= 900
+#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
// All MC kernels here
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
    userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank,
                                        const int myrank, const int gpustep, const int lineoffset,
                                        const int numlines, void **commbuff, const int handleridx,
-                                        float4 *mc_ptr) {
+                                        float4 *mc_ptr, const uint64_t ub_timeout) {
  int *flagptr, physgpu, targetgpu, *myptr;
  int *reduceidptr, reduce_id;

@@ -417,7 +417,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
  volatile int *flag = (volatile int *)&(myptr[targetgpu]);
  clock_t s = clock64();
  while (CHECK_IDS(*flag, reduce_id)) {
-    if (clock64() - s > TIMEOUT) {
+    if (clock64() - s > ub_timeout) {
      UB_PRINT("Reduce-scatter: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x,
               reduce_id, *flag);
      break;

@@ -484,7 +484,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
    volatile int *flag = (volatile int *)&myptr[targetgpu];
    clock_t s = clock64();
    while (CHECK_IDS(*flag, reduce_id)) {
-      if (clock64() - s > 2ull * TIMEOUT) {
+      if (clock64() - s > 2ull * ub_timeout) {
        UB_PRINT("Allgather: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id,
                 *flag);
        break;

@@ -741,7 +741,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
    userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank,
                                        const int myrank, const int gpustep, const int lineoffset,
                                        const int numlines, void **commbuff, const int handleridx,
-                                        float4 *mc_ptr) {}
+                                        float4 *mc_ptr, const uint64_t ub_timeout) {}
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
    userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset,

@@ -2496,6 +2496,18 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i
  }
}

+// reset counters kernel
+static __global__ void reset_counters_kernel(void *atomic_ptr, int num_chunks, bool allgather) {
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+#pragma unroll
+    for (int i = 0; i < num_chunks; i++) {
+      ((unsigned int *)atomic_ptr)[i] = 1;
+      ((unsigned int *)atomic_ptr)[i + num_chunks] = 0;
+    }
+    if (allgather) ((unsigned int *)atomic_ptr)[0] = 0;
+  }
+}
+
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
  dim3 block(1);
  dim3 grid(1);

@@ -2514,6 +2526,12 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr
  consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
}

+void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
+  dim3 block(1);
+  dim3 grid(1);
+  reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
+}
+
template <typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS / 4)
    reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,

@@ -2546,3 +2564,24 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output,
template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale,
                                                    int num_inputs, int input_size,
                                                    cudaStream_t stream);

+__global__ void __launch_bounds__(MAX_THREADS / 4)
+    reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) {
+  const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
+  half *inputs_half = reinterpret_cast<half *>(inputs);
+  float accum_buf = static_cast<float>(inputs_half[tid]);
+#pragma unroll
+  for (int i = 1; i < num_inputs; i++) {
+    accum_buf += static_cast<float>(inputs_half[tid + input_size * i]);
+  }
+  half *output_half = reinterpret_cast<half *>(output);
+  output_half[tid] = (half)accum_buf;
+}
+
+void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) {
+  size_t num_threads = MAX_THREADS / 4;
+  size_t num_blocks = (input_size + num_threads - 1) / num_threads;
+  dim3 block(num_threads);
+  dim3 grid(num_blocks);
+  reduce_bf16_cuda<<<grid, block, 0, stream>>>(inputs, output, num_inputs, input_size);
+}
@@ -19,11 +19,14 @@
#ifdef NVTE_UB_WITH_MPI
#include <mpi.h>
-typedef MPI_Comm ExtComm;
+#define ExtComm MPI_Comm
#else
-typedef char *ExtComm;
+#define ExtComm const char *
#endif

+using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
+using ExtBarrierOp = std::function<void(ExtComm)>;
+
#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_SMS 32
#define NVTE_MAX_OPS 32

@@ -142,12 +145,12 @@ struct communicator {
  volatile int tail;

  // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
-  std::function<void(void *, size_t, void *, size_t, ExtComm)> _allgather;
-  std::function<void(ExtComm)> _barrier;
+  ExtAllgatherOp _allgather;
+  ExtBarrierOp _barrier;

-  ExtComm comm_world,
-      comm_inter,  // reduction group communicator (subset of the nodes) along GPU rail
-      comm_intra;  // full intranode (all ndev GPUS)
+  ExtComm comm_world;
+  ExtComm comm_inter;  // reduction group communicator (subset of the nodes) along GPU rail
+  ExtComm comm_intra;  // full intranode (all ndev GPUS)
#ifdef NVTE_UB_WITH_MPI
  MPI_Request mpihndl[NVTE_MAX_SHARP];
#endif

@@ -161,23 +164,22 @@ typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream);
+void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream);

/* creates communicator, allocates all internal buffers if necessary */
-int create_communicator_grouped2(
-    communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
-    int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
-    std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
-    int tensornodes);
+int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
+                                 int numlocal, int mynode, int numnodes,
+                                 ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
+                                 int pipegpus, int pipenodes, int tensorgpus, int tensornodes);

-int create_communicator_grouped(
-    communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
-    int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
-    std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes);
+int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal,
+                                int numlocal, int mynode, int numnodes,
+                                ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
+                                int pipegpus, int pipenodes);

int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
-                        int mynode, int numnodes,
-                        std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
-                        std::function<void(ExtComm)> ext_barrier);
+                        int mynode, int numnodes, ExtAllgatherOp ext_allgather,
+                        ExtBarrierOp ext_barrier);

int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
                                     int tensorgpus, int tensornodes);

@@ -314,4 +316,6 @@ template <typename fp8type>
void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, int input_size,
                            cudaStream_t stream);

+void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);
+
#endif  // TRANSFORMER_ENGINE_USERBUFFERS_H_
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
#include <cuda.h>
#include <cuda_fp8.h>
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
namespace transformer_engine {
/* \brief Check if Userbuffers bootstraps with direct calls to MPI collectives.
 * This can be turned on by building Transformer Engine with the `NVTE_UB_WITH_MPI=1` option.
*
* \return True if Userbuffers is built with MPI
*/
bool ubuf_built_with_mpi();
enum class CommOverlapType { RS = 0, AG = 1 };
enum class CommOverlapAlgo {
BULK_OVERLAP_AG = 0,
BULK_OVERLAP_RS = 1,
SPLIT_PIPELINED_AG_P2P = 2,
SPLIT_PIPELINED_RS = 3,
SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7
};
class CommOverlapCore {
protected:
static inline communicator *_ub_comm{nullptr};
static inline bool _comm_created{false};
int _rank;
int _tp_id;
int _tp_size;
int _num_splits;
int _math_sms;
int _num_comm_sm;
int _cga_size;
int _use_ce;
int _ub_reg;
bool _atomic_gemm{false};
bool _is_p2p{false};
TensorWrapper _ubuf;
TensorWrapper _counter;
float *_ubuf_scale_inv;
bool _ubuf_scale_inv_initialized{false};
std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
public:
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm,
bool set_sm_margin, bool use_ce, bool atomic_gemm);
virtual ~CommOverlapCore();
void set_ubuf_scale_inv(float *scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; }
bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }
}; // CommOverlapCore
class CommOverlapBase : public CommOverlapCore {
protected:
int _rs_kernel_type;
cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy;
public:
CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);
virtual ~CommOverlapBase();
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main);
/*
** Split FPROP GEMM + ReduceScatter
*/
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap,
TensorWrapper &rs_output, cudaStream_t stream_main);
/*
** Split FPROP GEMM + ReduceScatter
*/
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main);
}; // CommOverlapBase
class CommOverlapP2PBase : public CommOverlapCore {
protected:
bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false};
int _next_rank;
int _prev_rank;
int _rank_round_tp;
int _aggregate;
int _num_ubuf_chunks;
int _self_chunk_id;
std::vector<TensorWrapper> _ubufs;
cudaStream_t _stream_send;
cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv;
public:
CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS,
int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false,
bool use_ce = true, bool atomic_gemm = false, bool aggregate = false);
virtual ~CommOverlapP2PBase();
  /*
  ** Split AllGather + AtomicGEMM using P2P communication
  ** This function assumes that input B is pre-copied to _ubufs[rank_id]. This is needed so that
  ** the AG output on each rank is contiguous in memory after all ring-exchange phases.
  */
void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main);
  /*
  ** Split AllGather + GEMM using P2P communication
  ** This function assumes that input B is pre-copied to _ubufs[rank_id]. This is needed so that
  ** the AG output on each rank is contiguous in memory after all ring-exchange phases.
  */
void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main);
/*
  ** Split ReduceScatter + AtomicGEMM using P2P communication
*/
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main);
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main);
}; // CommOverlapP2PBase
} // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
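As a quick orientation aid, the *_P2P values in CommOverlapAlgo select the ring-exchange implementations declared in CommOverlapP2PBase, while the remaining values map to the collective CommOverlapBase paths. A small helper sketch (not part of the API) that encodes that mapping:

// Sketch only: classify a CommOverlapAlgo value by which implementation class serves it.
#include <transformer_engine/comm_gemm_overlap.h>

namespace te = transformer_engine;

inline bool algo_uses_p2p(te::CommOverlapAlgo algo) {
  switch (algo) {
    case te::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P:
    case te::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P:
    case te::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P:
    case te::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P:
      return true;  // ring-exchange paths in CommOverlapP2PBase
    default:
      return false;  // BULK_OVERLAP_*, SPLIT_PIPELINED_RS, ATOMIC_GEMM_RS -> CommOverlapBase
  }
}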
@@ -78,13 +78,13 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType
 */
void nvte_destroy_tensor(NVTETensor tensor);

-/*! \brief Get a tensor's data type.
+/*! \brief Get a raw pointer to the tensor's data.
 *
 * \param[in] tensor Tensor.
 *
- * \return A data type of the input tensor.
+ * \return A raw pointer to tensor's data.
 */
-NVTEDType nvte_tensor_type(const NVTETensor tensor);
+void *nvte_tensor_data(const NVTETensor tensor);

/*! \brief Get a tensor's data shape.
 *

@@ -94,13 +94,46 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor);
 */
NVTEShape nvte_tensor_shape(const NVTETensor tensor);

-/*! \brief Get a raw pointer to the tensor's data.
+/*! \brief Get a tensor's number of dimensions.
 *
 * \param[in] tensor Tensor.
 *
- * \return A raw pointer to tensor's data.
+ * \return Number of tensor dimensions.
 */
-void *nvte_tensor_data(const NVTETensor tensor);
+size_t nvte_tensor_ndims(const NVTETensor tensor);
+
+/*! \brief Get the size of a specific tensor dimension.
+ *
+ * \param[in] tensor Tensor.
+ * \param[in] dim    Dimension index.
+ *
+ * \return Size of the tensor at the specified dimension.
+ */
+size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim);
+
+/*! \brief Get a tensor's total number of elements.
+ *
+ * \param[in] tensor Tensor.
+ *
+ * \return Number of elements in the tensor.
+ */
+size_t nvte_tensor_numel(const NVTETensor tensor);
+
+/*! \brief Get the byte size for the tensor's data type.
+ *
+ * \param[in] tensor Tensor.
+ *
+ * \return Byte size of the tensor's data type.
+ */
+size_t nvte_tensor_element_size(const NVTETensor tensor);
+
+/*! \brief Get a tensor's data type.
+ *
+ * \param[in] tensor Tensor.
+ *
+ * \return A data type of the input tensor.
+ */
+NVTEDType nvte_tensor_type(const NVTETensor tensor);

/*! \brief Get a pointer to the tensor's amax data.
 *

@@ -265,6 +298,56 @@ class TensorWrapper {
    return nvte_tensor_shape(tensor_);
  }

+  /*! \brief Get the size of this TensorWrapper in the given dimension.
+   *
+   * \param[in] dim Dimension index.
+   *
+   * \return Size of this TensorWrapper in given dimension.
+   */
+  size_t size(const size_t dim) const {
+    if (tensor_ == nullptr) return 0;
+    return nvte_tensor_size(tensor_, dim);
+  }
+
+  /*! \brief Get the number of dimensions for this TensorWrapper.
+   *
+   * \return Number of dimensions for this TensorWrapper.
+   */
+  size_t ndim() const noexcept {
+    if (tensor_ == nullptr) return 0;
+    return nvte_tensor_ndims(tensor_);
+  }
+
+  /*! \brief Get the number of allocated elements in the tensor. This will return 0 for tensors
+   *         with nullptr data even if the TensorWrapper has a non-zero shape.
+   *
+   * \return Number of elements in the tensor.
+   */
+  size_t numel() const noexcept {
+    if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
+    return nvte_tensor_numel(tensor_);
+  }
+
+  /*! \brief Get the tensor's element size in bytes.
+   *
+   * \return Element size in bytes.
+   */
+  size_t element_size() const noexcept {
+    if (tensor_ == nullptr) return 0;
+    return nvte_tensor_element_size(tensor_);
+  }
+
+  /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr
+   *         data even if the TensorWrapper has a non-zero shape and valid dtype.
+   *
+   * \return Total tensor size in bytes.
+   */
+  size_t bytes() const noexcept {
+    if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
+    return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_);
+  }
+
  /*! \brief Get the data type of this TensorWrapper.
   *
   * \return Data type of this TensorWrapper.

@@ -317,6 +400,6 @@ class TensorWrapper {

}  // namespace transformer_engine

-#endif
+#endif  // __cplusplus

#endif  // TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
@@ -93,6 +93,31 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  return ret;
}

+size_t nvte_tensor_ndims(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  return t.data.shape.size();
+}
+
+size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim);
+  return t.data.shape[dim];
+}
+
+size_t nvte_tensor_numel(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  size_t numel = 1;
+  for (auto size : t.data.shape) {
+    numel *= size;
+  }
+  return numel;
+}
+
+size_t nvte_tensor_element_size(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  return transformer_engine::typeToSize(t.data.dtype);
+}
+
void *nvte_tensor_data(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  return t.data.dptr;
......
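Together with the nvte_tensor_* getters declared in the header hunk above, these implementations let the overlap code do its chunk-size arithmetic directly on TensorWrapper objects. A hedged usage sketch, assuming it is compiled and linked against the Transformer Engine headers/library and using a host buffer purely as a stand-in for device memory:

// Sketch: the new size accessors on a TensorWrapper, constructed the same way as the
// chunk wrappers in the comm+GEMM overlap code (explicit nullptr amax/scale/scale_inv).
#include <cstdio>
#include <vector>
#include <transformer_engine/transformer_engine.h>

int main() {
  using namespace transformer_engine;
  std::vector<float> buf(8 * 16);  // arbitrary host data standing in for a device buffer
  TensorWrapper t(buf.data(), std::vector<size_t>{8, 16}, DType::kFloat32, nullptr, nullptr,
                  nullptr);
  std::printf("ndim=%zu size(0)=%zu numel=%zu element_size=%zu bytes=%zu\n", t.ndim(), t.size(0),
              t.numel(), t.element_size(), t.bytes());
  return 0;
}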
@@ -12,6 +12,7 @@
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/system.h"
+#include "common/util/cuda_runtime.h"

namespace transformer_engine {

@@ -80,6 +81,31 @@ int sm_count(int device_id) {
  return cache[device_id];
}

+bool supports_multicast(int device_id) {
+#if CUDART_VERSION >= 12010
+  // NOTE: This needs to be guarded at compile time because the
+  // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
+  static std::vector<bool> cache(num_devices(), false);
+  static std::vector<std::once_flag> flags(num_devices());
+  if (device_id < 0) {
+    device_id = current_device();
+  }
+  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
+  auto init = [&]() {
+    CUdevice cudev;
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
+    int result;
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
+                                CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
+    cache[device_id] = static_cast<bool>(result);
+  };
+  std::call_once(flags[device_id], init);
+  return cache[device_id];
+#else
+  return false;
+#endif
+}
+
const std::string &include_directory(bool required) {
  static std::string path;
......
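A short usage sketch of the new runtime query, mirroring how create_communicator_grouped2 now gates CUDA Multicast initialization; it only assumes the common/util/cuda_runtime.h header shown in the includes above:

// Sketch: decide at runtime whether the multicast Userbuffers path is even available.
#include <cstdio>
#include "common/util/cuda_runtime.h"

int main() {
  if (transformer_engine::cuda::supports_multicast()) {
    std::printf("CUDA Multicast is available on the current device\n");
  } else {
    std::printf("CUDA Multicast unavailable; Userbuffers would fall back to the non-MC path\n");
  }
  return 0;
}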
@@ -38,6 +38,14 @@ int sm_arch(int device_id = -1);
 */
int sm_count(int device_id = -1);

+/* \brief CUDA Multicast support status for device
+ *
+ * \param[in] device_id CUDA device (default is current device)
+ *
+ * \return CUDA multicast support flag
+ */
+bool supports_multicast(int device_id = -1);
+
/* \brief Path to CUDA Toolkit headers
 *
 * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_
#include <pybind11/pybind11.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>
#include "cuda_runtime.h"
#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \
pybind11::enum_<transformer_engine::DType>(m, "DType") \
.value("kByte", transformer_engine::DType::kByte) \
.value("kInt32", transformer_engine::DType::kInt32) \
.value("kFloat32", transformer_engine::DType::kFloat32) \
.value("kFloat16", transformer_engine::DType::kFloat16) \
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type") \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \
.value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \
pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type") \
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout") \
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \
.value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \
.value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \
.value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \
.value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \
.value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend") \
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType") \
.value("RS", transformer_engine::CommOverlapType::RS) \
.value("AG", transformer_engine::CommOverlapType::AG); \
pybind11::enum_<transformer_engine::CommOverlapAlgo>(m, "CommOverlapAlgo") \
.value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \
.value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \
.value("SPLIT_PIPELINED_AG_P2P", \
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \
.value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \
.value("SPLIT_PIPELINED_RS_P2P", \
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \
m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \
py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \
m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \
py::call_guard<py::gil_scoped_release>());
#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_
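For context, a framework extension would typically invoke this macro inside its own pybind11 module definition so that the shared enums and helpers are registered once. A hedged sketch; the module name my_te_extension and the include path are illustrative assumptions, not part of this commit:

// Sketch only: pulling the common TE bindings into a hypothetical extension module.
#include <pybind11/pybind11.h>

#include "common/util/pybind_helper.h"

namespace py = pybind11;  // the macro body refers to the py:: alias

PYBIND11_MODULE(my_te_extension, m) {
  NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m)
  // framework-specific bindings would follow here
}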