Unverified commit 933294dc authored by Alp Dener, committed by GitHub

[C/PyTorch] Userbuffers and comm+GEMM overlap algorithms refactored and moved to TE/common (#1067)



* moved userbuffers code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved comm+GEMM overlap code to TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed PyTorch dependency from comm+GEMM overlap in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE/common
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated TE/PyTorch Python API to match the refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* updated unit tests to work with refactored comm+GEMM overlap code
Signed-off-by: Alp Dener <adener@nvidia.com>

* added a pylint exception to comm+GEMM overlap test runner
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* added documentation for te.initialize_ub
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed compile errors when building with NVTE_UB_WITH_MPI=1
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed default bootstrap backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* switched default bootstrap backend priority to MPI > Gloo > NCCL
Signed-off-by: Alp Dener <adener@nvidia.com>
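The MPI > Gloo > NCCL bootstrap priority named in this commit can be modeled with a small first-available lookup. This is a hypothetical pure-Python sketch for illustration only; the actual selection logic lives in TE/PyTorch and queries torch.distributed for backend availability, and the function name here is not from the source.

```python
def pick_bootstrap_backend(available_backends):
    """Return the highest-priority available bootstrap backend (MPI > Gloo > NCCL)."""
    for backend in ("mpi", "gloo", "nccl"):
        if backend in available_backends:
            return backend
    raise RuntimeError("no distributed backend available for UB bootstrapping")
```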

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated bootstrap backend documentation
Signed-off-by: Alp Dener <adener@nvidia.com>

* close UB bootstrap socket to avoid interfering with CUDA Multicast shareable file handle send/recv
Signed-off-by: Alp Dener <adener@nvidia.com>

* added torch::Tensor wrappers for communication buffer and atomic counters so PyTorch can factor externally allocated memory into its garbage collection threshold
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* automated handling of world, local and node ranks/sizes within C++ CommOverlapHelper to simplify Python function signatures
Signed-off-by: Alp Dener <adener@nvidia.com>
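The rank bookkeeping this commit moves into CommOverlapHelper can be illustrated with the usual modular arithmetic, assuming ranks are assigned contiguously within each node. This sketch is an assumption for illustration; the real helper derives these values from the torch.distributed process groups rather than from raw arithmetic.

```python
def derive_ranks(world_rank, world_size, gpus_per_node):
    """Split a flat world rank into (local_rank, local_size, node_id, num_nodes)."""
    local_size = min(gpus_per_node, world_size)
    local_rank = world_rank % local_size
    node_id = world_rank // local_size
    num_nodes = (world_size + local_size - 1) // local_size
    return local_rank, local_size, node_id, num_nodes
```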

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed incorrect read of environment variables
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected priority for _SOCKET_IFNAME environment variables in UB bootstrapping
Signed-off-by: Alp Dener <adener@nvidia.com>

* moved multicast support check to cuda_runtime.h and replaced cudaDeviceGetProp call with cached sm_count()
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* removed commented out old code and replaced external collective function type defines with aliases
Signed-off-by: Alp Dener <adener@nvidia.com>

* compile-time CUDA version guard for CUDA Driver Multicast attribute
Signed-off-by: Alp Dener <adener@nvidia.com>

* added compile-time CUDA version guards to Multicast code in Userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* condensed UB docs, corrected const violations
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed autodoc rst for UB calls, added CUDA version guard on Multicast UB kernels
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect UB type reporting for P2P overlaps, comment reformatting
Signed-off-by: Alp Dener <adener@nvidia.com>

* add docstring to tex.ubuf_built_with_mpi()
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 35bbe740
@@ -38,3 +38,4 @@ dist/
 downloads/
 .pytest_cache/
 compile_commands.json
+.nfs
@@ -11,7 +11,6 @@ import setuptools
 from .utils import (
     all_files_in_dir,
     cuda_archs,
-    cuda_path,
     cuda_version,
 )
@@ -29,9 +28,6 @@ def setup_pytorch_extension(
     sources = [
         csrc_source_files / "common.cu",
         csrc_source_files / "ts_fp8_op.cpp",
-        csrc_source_files / "userbuffers" / "ipcsocket.cc",
-        csrc_source_files / "userbuffers" / "userbuffers.cu",
-        csrc_source_files / "userbuffers" / "userbuffers-host.cpp",
     ] + all_files_in_dir(extensions_dir)

     # Header files
@@ -85,19 +81,14 @@ def setup_pytorch_extension(
             continue  # Already handled
         nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

-    # Libraries
-    library_dirs = []
-    libraries = []
-    if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
+    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
         assert (
             os.getenv("MPI_HOME") is not None
-        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
-        mpi_home = Path(os.getenv("MPI_HOME"))
-        include_dirs.append(mpi_home / "include")
+        ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
+        mpi_path = Path(os.getenv("MPI_HOME"))
+        include_dirs.append(mpi_path / "include")
         cxx_flags.append("-DNVTE_UB_WITH_MPI")
         nvcc_flags.append("-DNVTE_UB_WITH_MPI")
-        library_dirs.append(mpi_home / "lib")
-        libraries.append("mpi")

     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]
@@ -112,6 +103,4 @@ def setup_pytorch_extension(
             "cxx": cxx_flags,
             "nvcc": nvcc_flags,
         },
-        libraries=[str(lib) for lib in libraries],
-        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
     )
@@ -51,3 +51,7 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.moe_permute

 .. autoapifunction:: transformer_engine.pytorch.moe_unpermute
+
+.. autoapifunction:: transformer_engine.pytorch.initialize_ub
+
+.. autoapifunction:: transformer_engine.pytorch.destroy_ub
@@ -57,13 +57,20 @@ class TimedBdist(bdist_wheel):

 def setup_common_extension() -> CMakeExtension:
     """Setup CMake extension for common library"""
+    cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())]
+    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
+        assert (
+            os.getenv("MPI_HOME") is not None
+        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
+        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")
+
     # Project directory root
     root_path = Path(__file__).resolve().parent

     return CMakeExtension(
         name="transformer_engine",
         cmake_path=root_path / Path("transformer_engine/common"),
-        cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())],
+        cmake_flags=cmake_flags,
     )
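The env-var gating added to setup_common_extension() above can be exercised in isolation. This is a simplified standalone sketch: it takes the environment as a plain dict, and the CUDA architecture value is illustrative only (the real code calls cuda_archs()).

```python
def common_cmake_flags(env):
    """Build the CMake flag list, adding the MPI option only when requested via env vars."""
    flags = ["-DCMAKE_CUDA_ARCHITECTURES=90"]  # illustrative; real code queries cuda_archs()
    if bool(int(env.get("NVTE_UB_WITH_MPI", "0"))):
        # MPI bootstrapping needs an MPI installation to compile and link against
        assert env.get("MPI_HOME") is not None, (
            "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        )
        flags.append("-DNVTE_UB_WITH_MPI=ON")
    return flags
```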
@@ -22,6 +22,7 @@ import transformer_engine.pytorch.cpp_extensions as tex
 from transformer_engine.common.recipe import Format
 from transformer_engine.pytorch.fp8 import _default_sf_compute

+warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=UserWarning)
@@ -32,8 +33,8 @@ torch_dtypes = {
 }

 nvte_comm_types = {
-    "rs": 0,
-    "ag": 1,
+    "rs": tex.CommOverlapType.RS,
+    "ag": tex.CommOverlapType.AG,
 }
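The typemap above is consumed by the test runner's argparse setup via a mapping type function. This sketch shows the pattern with plain ints standing in for the tex.CommOverlapType enum values; the body of the real _mapped_argtype helper is not shown in this diff, so this implementation is an assumption.

```python
import argparse
from functools import partial

def mapped_argtype(opt, typemap):
    """Translate a command-line string (e.g. 'rs'/'ag') into its mapped value."""
    key = str(opt).strip().lower()
    if key not in typemap:
        raise argparse.ArgumentTypeError(f"{opt!r} is not one of {sorted(typemap)}")
    return typemap[key]

# Ints stand in for tex.CommOverlapType.RS / .AG here
comm_types = {"rs": 0, "ag": 1}
parser = argparse.ArgumentParser()
parser.add_argument(
    "--comm-type",
    type=partial(mapped_argtype, typemap=comm_types),
    default=1,
)
```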
@@ -75,7 +76,7 @@ def _parse_args(argv=None, namespace=None):
     parser.add_argument(
         "--comm-type",
         type=partial(_mapped_argtype, typemap=nvte_comm_types),
-        default=0,
+        default=tex.CommOverlapType.AG,
         help="Comm type to overlap.",
     )
     parser.add_argument(
@@ -156,11 +157,9 @@ def _parse_args(argv=None, namespace=None):
         if opts.fp8:
             warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
             opts.fp8 = False
-    elif opts.comm_type == 1:
+    elif opts.comm_type == tex.CommOverlapType.AG:
         if opts.atomic:
             setattr(opts, "atomic_rs_p2p", opts.p2p)
-        if not opts.p2p:
-            warnings.warn("All-gather overlap is only supported with point-2-point comms.")
-            opts.p2p = True
+        opts.p2p = True

     if opts.atomic:
@@ -283,35 +282,35 @@ def _main(opts):
     if WORLD_RANK == 0:
         print("\n", end="", flush=True)

-    ub_callbacks = (
-        tex.UbufBootstrapCallbacks()
+    helper = (
+        tex.CommOverlapHelper()
         if tex.ubuf_built_with_mpi()
-        else tex.UbufBootstrapCallbacks(bootstrap_pg, bootstrap_pg)
+        else tex.CommOverlapHelper(bootstrap_pg)
     )

-    if opts.comm_type == 0:
+    if opts.comm_type == tex.CommOverlapType.RS:
         if opts.bulk_overlap:
-            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_RS
+            ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS
         elif opts.p2p:
             ub_algo = (
-                tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                 if opts.atomic
-                else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
             )
         else:
             ub_algo = (
-                tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                 if opts.atomic
-                else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
+                else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
             )
-    elif opts.comm_type == 1:
+    elif opts.comm_type == tex.CommOverlapType.AG:
         if opts.bulk_overlap:
-            ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
+            ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG
         else:
             ub_algo = (
-                tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+                tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
                 if opts.atomic
-                else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+                else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
             )
     else:
         raise TypeError("Invalid comm+GEMM overlap type!")
@@ -322,95 +321,55 @@ def _main(opts):
     hidden_size = opts.num_heads * opts.head_dim
     inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
     outer_size = reduce(operator.mul, inp_shape[:-1], 1)
-    ubuf_dtype = torch.bfloat16
-    if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output):
-        ubuf_dtype = torch.uint8
-    sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda")
-    ub_obj = (
-        tex.UbufP2PCommOverlap(
-            sample_buffer,  # Sample userbuffer
-            WORLD_RANK,  # World rank
-            WORLD_SIZE,  # World size
-            LOCAL_RANK,  # Rank within the node
-            LOCAL_SIZE,  # Number of ranks/GPUs per node
-            0,  # Node ID
-            1,  # Number of nodes
+    buffer_dtype = torch.bfloat16
+    if (
+        opts.fp8
+        and not opts.bulk_overlap
+        and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output)
+    ):
+        buffer_dtype = torch.uint8
+    ub_obj = (
+        tex.CommOverlapP2P(
+            (outer_size, hidden_size),
+            buffer_dtype,
+            helper,
             tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
-            1,  # Number of communication SMs
-            1,  # CGA cluster size
-            opts.comm_type == 0 or opts.atomic,  # Set SM margin
-            opts.aggregate,  # Aggregate 2X GEMM chunks
-            3,  # Max concurrent GEMM streams
-            opts.comm_type == 0,  # overlap with reduce scatter
-            opts.atomic,  # use a single GEMM with atomic-counters
-            not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
-            ub_callbacks,
+            opts.comm_type,
+            set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
+            atomic_gemm=opts.atomic,
+            aggregate=opts.aggregate,
+            use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
         )
         if opts.p2p
-        else tex.UbufCommOverlap(
-            sample_buffer,  # Sample userbuffer
-            WORLD_RANK,  # World rank
-            WORLD_SIZE,  # World size
-            LOCAL_RANK,  # Rank within the node
-            LOCAL_SIZE,  # Number of ranks/GPUs per node
-            0,  # Node ID
-            1,  # Number of nodes
+        else tex.CommOverlap(
+            (outer_size, hidden_size),
+            buffer_dtype,
+            helper,
             tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
-            16,  # Number of communication SMs
-            2,  # CGA cluster size
-            4,  # Number of communication splits
-            True,  # Set SM margin
-            3,  # Max concurrent GEMM streams
-            opts.atomic,  # Use a single GEMM with atomic-counters
-            ub_callbacks,
+            atomic_gemm=opts.atomic,
         )
     )

     # Numerical check on AG + atomic GEMM requires testing an AG+RS pair
     ub_obj2 = None
-    if opts.atomic and opts.comm_type == 1 and opts.check_numerics:
-        sample_buffer2 = torch.empty(
-            (outer_size, hidden_size),
-            dtype=torch.uint8 if opts.fp8_output else torch.bfloat16,
-            device="cuda",
-        )
+    if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics:
         ub_obj2 = (
-            tex.UbufP2PCommOverlap(
-                sample_buffer2,  # Sample userbuffer
-                WORLD_RANK,  # World rank
-                WORLD_SIZE,  # World size
-                LOCAL_RANK,  # Rank within the node
-                LOCAL_SIZE,  # Number of ranks/GPUs per node
-                0,  # Node ID
-                1,  # Number of nodes
+            tex.CommOverlapP2P(
+                (outer_size, hidden_size),
+                torch.uint8 if opts.fp8_output else torch.bfloat16,
+                helper,
                 tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
-                1,  # Number of communication SMs
-                1,  # CGA cluster size
-                True,  # Set SM margin
-                False,  # Aggregate 2X GEMM chunks
-                3,  # Max concurrent GEMM streams
-                True,  # overlap with reduce scatter
-                True,  # use a single GEMM with atomic-counters
-                True,  # use copy engine for P2P communications
-                ub_callbacks,
+                tex.CommOverlapType.RS,
+                set_sm_margin=True,
+                atomic_gemm=True,
             )
             if opts.atomic_rs_p2p
-            else tex.UbufCommOverlap(
-                sample_buffer2,  # Sample userbuffer
-                WORLD_RANK,  # World rank
-                WORLD_SIZE,  # World size
-                LOCAL_RANK,  # Rank within the node
-                LOCAL_SIZE,  # Number of ranks/GPUs per node
-                0,  # Node ID
-                1,  # Number of nodes
+            else tex.CommOverlap(
+                (outer_size, hidden_size),
+                torch.uint8 if opts.fp8_output else torch.bfloat16,
+                helper,
                 tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
-                16,  # Number of communication SMs
-                2,  # CGA cluster size
-                4,  # Number of communication splits
-                True,  # Set SM margin
-                3,  # Max concurrent GEMM streams
-                True,  # Use a single GEMM with atomic-counters
-                ub_callbacks,
+                atomic_gemm=True,
             )
         )
@@ -426,12 +385,12 @@ def _main(opts):
         local_kernel_t_shape = (ffn_hidden_size, hidden_size)
         local_inp_shape = (outer_size, hidden_size)
         # Bulk overlap comm tensor is distributed for AG overlap only
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             bulk_inp_shape = (outer_size // tp_size, hidden_size)
         else:
             bulk_inp_shape = (outer_size, hidden_size)
     else:
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             # (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P)
             local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size)
             local_inp_shape = (outer_size // tp_size, hidden_size)
@@ -472,7 +431,7 @@ def _main(opts):
             std=opts.std,
         )
     else:
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K)
             ker_g = torch.transpose(
                 te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1
@@ -494,7 +453,7 @@ def _main(opts):
     ).to(dtype=torch.float32)

     if opts.bulk_overlap:
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0]
         else:
             # First all-gather all the bulk inputs into a list
@@ -505,7 +464,7 @@ def _main(opts):
     else:
         ref_g = torch.matmul(inp_g, ker_g)
         if ub_obj2 is not None:
-            inp2_g = torch.nn.functional.gelu(ref_g)
+            inp2_g = torch.nn.functional.gelu(ref_g)  # pylint: disable=not-callable
             ref2_g = torch.matmul(inp2_g, ker2_g)

     if opts.fp8:
@@ -529,7 +488,7 @@ def _main(opts):
         fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax)
         ref_amax = torch.max(torch.abs(ref_g))
         fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax)
-        if opts.bulk_overlap and opts.comm_type == 0:
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             bulk_amax = torch.max(torch.abs(bulk_inp))
             fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax)
         elif ub_obj2 is not None:
@@ -551,7 +510,7 @@ def _main(opts):
         kernel_t_fp8 = tex.cast_to_fp8(
             kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype
         )
-        if opts.bulk_overlap and opts.comm_type == 0:
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             bulk_inp_fp8 = tex.cast_to_fp8(
                 bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype
             )
@@ -574,7 +533,7 @@ def _main(opts):
             rtol=0.125,
             atol=0.0675,
         )
-        if opts.bulk_overlap and opts.comm_type == 0:
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             torch.allclose(
                 bulk_inp.to(dtype=torch.float32),
                 bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT],
@@ -590,7 +549,7 @@ def _main(opts):
         )

         # Set Fp8 scales for userbuffers
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT])
             if ub_obj2 is not None:
                 ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT])
@@ -602,7 +561,7 @@ def _main(opts):
     # Set up comm/compute buffers
     ubuf_out2 = None
     rs_out2 = None
-    if opts.comm_type == 1:
+    if opts.comm_type == tex.CommOverlapType.AG:
         if opts.bulk_overlap:
             ub_obj.copy_input_to_ubuf(bulk_inp, 1)
             gemm_inp = inp
@@ -686,9 +645,9 @@ def _main(opts):
             gelu=False,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
             ub_algo=(
-                tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                 if opts.atomic_rs_p2p
-                else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                else tex.CommOverlapAlgo.ATOMIC_GEMM_RS
             ),
             ub=ub_obj2,
             extra_output_tensor=rs_out2,
@@ -762,10 +721,14 @@ def _main(opts):
     avg_gpu_time = sum(gpu_times) / opts.timing_iters
     gemm_name = "".join(
         [
-            "p2p all-gather + " if opts.comm_type == 1 else "",
+            "p2p all-gather + " if opts.comm_type == tex.CommOverlapType.AG else "",
             "atomic " if opts.atomic else "",
             "GEMM",
-            (f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == 0 else ""),
+            (
+                f" + {'p2p ' if opts.p2p else ''}reduce-scatter"
+                if opts.comm_type == tex.CommOverlapType.RS
+                else ""
+            ),
         ]
     )
     timing_info = (
@@ -781,7 +744,7 @@ def _main(opts):
         dist.barrier(tp_group)

     if opts.bulk_overlap:
         output_info = ""
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             # Bulk overlap AG output is already gathered
             test_out = ub_obj.get_ubuf_output(1)
         else:
@@ -794,7 +757,7 @@ def _main(opts):
         output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}"
         dist_print(
             output_info,
-            src=0 if opts.comm_type == 0 else None,
+            src=0 if opts.comm_type == tex.CommOverlapType.RS else None,
             section=True,
         )
@@ -805,7 +768,7 @@ def _main(opts):
         )
         dist_print(nonzero_info, src=0, section=True, group=tp_group)
     else:
-        if opts.comm_type == 1:
+        if opts.comm_type == tex.CommOverlapType.AG:
             if ub_obj2 is not None:
                 # AG+RS Output: (M/P, N) -> gather -> (M, N)
                 output = rs_out2.to(dtype=torch.float32)
@@ -9,7 +9,6 @@ import sys
 import socket
 import argparse
 import warnings
-from functools import partial

 import torch
 import torch.distributed as dist
@@ -42,6 +42,9 @@ if not tex.device_supports_multicast():
 # Force GPU kernels to launch in the order they're executed by the host CPU
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

+# Clear torch.dynamo caches
+torch._dynamo.reset()
+

 def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate):
     test_path = TEST_ROOT / "run_gemm_with_overlap.py"
@@ -80,7 +80,11 @@ list(APPEND transformer_engine_SOURCES
      fused_softmax/scaled_upper_triang_masked_softmax.cu
      fused_softmax/scaled_aligned_causal_masked_softmax.cu
      fused_rope/fused_rope.cu
-     recipe/delayed_scaling.cu)
+     recipe/delayed_scaling.cu
+     comm_gemm_overlap/userbuffers/ipcsocket.cc
+     comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+     comm_gemm_overlap/userbuffers/userbuffers.cu
+     comm_gemm_overlap/comm_gemm_overlap.cpp)

 add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 target_include_directories(transformer_engine PUBLIC
                            "${CMAKE_CURRENT_SOURCE_DIR}/include")
@@ -93,6 +97,15 @@ target_include_directories(transformer_engine PRIVATE
     ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
 target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")

+# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
+option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
+if (NVTE_UB_WITH_MPI)
+    find_package(MPI REQUIRED)
+    target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX)
+    target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES})
+    target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
+endif()
+
 # Hack to enable dynamic loading in cuDNN frontend
 target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
@@ -20,7 +20,9 @@
 #include <utility>

 #include "common/util/cuda_driver.h"
+#include "common/util/cuda_runtime.h"
 #include "common/util/logging.h"
+#include "common/util/system.h"
 #include "ipcsocket.h"
 #include "userbuffers.h"
@@ -44,31 +46,19 @@ static MPI_Comm EXT_COMM_INTER;
   } while (false)

 void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
-                      ExtComm group) {
-  // UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE,
-  //                            globaldata, globalbytes, MPI_BYTE,
-  //                            static_cast<MPI_Comm>(group)));
-  MPI_Comm comm = static_cast<MPI_Comm>(group);
+                      ExtComm comm) {
   int numranks;
   UB_MPI_CHECK(MPI_Comm_size(comm, &numranks));
   assert(globalbytes == numranks * localbytes);
-
-  int myrank;
-  UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank));
-  char *globaltarget = reinterpret_cast<char *>(globaldata) + (myrank * localbytes);
-  memcpy(globaltarget, localdata, localbytes);
-  for (int n = 0; n < numranks; n++) {
-    globaltarget = reinterpret_cast<char *>(globaldata) + (n * localbytes);
-    UB_MPI_CHECK(MPI_Bcast(globaltarget, localbytes, MPI_BYTE, n, comm));
-  }
+  UB_MPI_CHECK(
+      MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm));
 }

-void ub_mpi_barrier(ExtComm group) { UB_MPI_CHECK(MPI_Barrier(static_cast<MPI_Comm>(group))); }
+void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); }
 #else
-static char EXT_COMM_WORLD[] = "world";
-static char EXT_COMM_INTRA[] = "intra";
-static char EXT_COMM_INTER[] = "inter";
+#define EXT_COMM_WORLD "world"
+#define EXT_COMM_INTRA "intra"
+#define EXT_COMM_INTER "inter"
 #endif

 #define MULTICAST_GB_TOTAL 512
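The hunk above replaces a hand-rolled copy-and-broadcast loop with a single MPI_Allgather call, keeping the invariant that the global buffer is exactly numranks equal-sized local chunks. A pure-Python model of that allgather contract, for illustration only (not part of the source):

```python
def allgather_sim(local_chunks):
    """Model MPI_Allgather: each rank contributes an equal-sized chunk and
    every rank receives the concatenation of all chunks, in rank order."""
    localbytes = len(local_chunks[0])
    # Mirrors the assert(globalbytes == numranks * localbytes) check above
    assert all(len(c) == localbytes for c in local_chunks), "chunks must be equal-sized"
    globaldata = b"".join(local_chunks)
    assert len(globaldata) == len(local_chunks) * localbytes
    return [globaldata for _ in local_chunks]
```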
@@ -106,11 +96,10 @@ int pipe_rank(communicator *comm, int step) {
   return newnode * numlocal + newlocal;
 }

-int create_communicator_grouped2(
-    communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
-    int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
-    std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
-    int tensornodes) {
+int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
+                                 int numlocal, int mynode, int numnodes,
+                                 ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
+                                 int pipegpus, int pipenodes, int tensorgpus, int tensornodes) {
   *comm = new communicator();

   (*comm)->comm_world = EXT_COMM_WORLD;
@@ -214,8 +203,11 @@ int create_communicator_grouped2(
   (*comm)->asyncblocks = 16;

 #define NBUF 2
-  if ((*comm)->sm_arch >= 9 && (*comm)->ar2_nvsize > 1 &&
-      !getenv("UB_SKIPMC")) {  // multicast init only for TP ops (____2 operations)
+
+#if CUDART_VERSION >= 12010
+  if (!transformer_engine::getenv<bool>("UB_SKIPMC") &&
+      transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) {
+    // multicast init only for TP ops (____2 operations)
     size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30);
     (*comm)->mc_offset = 0;
     (*comm)->use_mc = 1;
@@ -291,20 +283,20 @@ int create_communicator_grouped2(
     (*comm)->_barrier((*comm)->comm_world);
     if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize);
   } else {
+#endif
     if (!(*comm)->myrank) printf("MC NOT initialized and used\n");
     (*comm)->mc_maxsize = 0;
     (*comm)->mc_offset = 0;
     (*comm)->use_mc = 0;
+#if CUDART_VERSION >= 12010
   }
+#endif

 #define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
   // peer pointers + op flags + comm buffer
-  NVTE_CHECK_CUDA(
-      cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE));  // flags and pointers, no block data yet
-  NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
-  register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false);
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
+  register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true);

   NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
   NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
   NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
...@@ -346,18 +338,17 @@ int create_communicator_grouped2( ...@@ -346,18 +338,17 @@ int create_communicator_grouped2(
return 0; return 0;
} }
int create_communicator_grouped( int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal,
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numlocal, int mynode, int numnodes,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather, ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes) { int pipegpus, int pipenodes) {
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1); ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1);
} }
int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int mynode, int numnodes, ExtAllgatherOp ext_allgather,
std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather, ExtBarrierOp ext_barrier) {
std::function<void(ExtComm)> ext_barrier) {
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_allgather, ext_barrier, 1, 1, 1, 1); ext_allgather, ext_barrier, 1, 1, 1, 1);
} }
...@@ -428,7 +419,7 @@ int create_communicator_mpi(communicator **comm) { ...@@ -428,7 +419,7 @@ int create_communicator_mpi(communicator **comm) {
void destroy_communicator(communicator *comm) { void destroy_communicator(communicator *comm) {
for (int hndl = 0; hndl < comm->free_region; hndl++) { for (int hndl = 0; hndl < comm->free_region; hndl++) {
if (hndl > 0 && comm->use_mc && comm->mem_dealloc[hndl]) { if (comm->use_mc && comm->mem_dealloc[hndl]) {
for (int rank = 0; rank < comm->nvsize; rank++) { for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank == comm->nvrank) { if (rank == comm->nvrank) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]); NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);
...@@ -479,6 +470,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * ...@@ -479,6 +470,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->memflags[hndl] = 0; comm->memflags[hndl] = 0;
comm->mem_dealloc[hndl] = alloc; comm->mem_dealloc[hndl] = alloc;
#if CUDART_VERSION >= 12010
if (comm->use_mc && alloc) { if (comm->use_mc && alloc) {
int nranks = comm->nvsize; // total GPUs in NVLINK domain int nranks = comm->nvsize; // total GPUs in NVLINK domain
int myrank = comm->nvrank; int myrank = comm->nvrank;
...@@ -594,6 +586,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * ...@@ -594,6 +586,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
} }
} else { } else {
#endif
if (alloc) { if (alloc) {
NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes)); NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes)); NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));
...@@ -624,7 +617,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * ...@@ -624,7 +617,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
NVTE_CHECK_CUDA(cudaDeviceSynchronize()); NVTE_CHECK_CUDA(cudaDeviceSynchronize());
free(tmp); free(tmp);
#if CUDART_VERSION >= 12010
} }
#endif
comm->mem_size[hndl] = aligned_size; comm->mem_size[hndl] = aligned_size;
comm->mem_ptr[hndl] = *gpubuff; comm->mem_ptr[hndl] = *gpubuff;
......
@@ -392,14 +392,14 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
  if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
}  // fp16 reduce-scatter kernel (out of place)

#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
// All MC kernels here
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
    userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank,
                                        const int myrank, const int gpustep, const int lineoffset,
                                        const int numlines, void **commbuff, const int handleridx,
                                        float4 *mc_ptr, const uint64_t ub_timeout) {
  int *flagptr, physgpu, targetgpu, *myptr;
  int *reduceidptr, reduce_id;
@@ -417,7 +417,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
  volatile int *flag = (volatile int *)&(myptr[targetgpu]);
  clock_t s = clock64();
  while (CHECK_IDS(*flag, reduce_id)) {
    if (clock64() - s > ub_timeout) {
      UB_PRINT("Reduce-scatter: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x,
               reduce_id, *flag);
      break;
@@ -484,7 +484,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
  volatile int *flag = (volatile int *)&myptr[targetgpu];
  clock_t s = clock64();
  while (CHECK_IDS(*flag, reduce_id)) {
    if (clock64() - s > 2ull * ub_timeout) {
      UB_PRINT("Allgather: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id,
               *flag);
      break;
@@ -741,7 +741,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
    userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank,
                                        const int myrank, const int gpustep, const int lineoffset,
                                        const int numlines, void **commbuff, const int handleridx,
                                        float4 *mc_ptr, const uint64_t ub_timeout) {}
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
    userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset,
@@ -2496,6 +2496,18 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i
  }
}
// reset counters kernel
static __global__ void reset_counters_kernel(void *atomic_ptr, int num_chunks, bool allgather) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < num_chunks; i++) {
((unsigned int *)atomic_ptr)[i] = 1;
((unsigned int *)atomic_ptr)[i + num_chunks] = 0;
}
if (allgather) ((unsigned int *)atomic_ptr)[0] = 0;
}
}
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
  dim3 block(1);
  dim3 grid(1);
@@ -2514,6 +2526,12 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr
  consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
}
void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
}
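The new reset kernel re-arms the counter array between iterations: slots `[0, num_chunks)` hold producer flags set back to 1, slots `[num_chunks, 2 * num_chunks)` hold consumer flags cleared to 0, and the all-gather path additionally clears slot 0. A minimal host-side sketch of that same layout (the `reset_counters_host` helper is hypothetical, for illustration only):

```cpp
#include <cassert>
#include <vector>

// Host-side analogue of reset_counters_kernel: identical flag layout,
// plain memory instead of a device pointer (illustrative only).
inline void reset_counters_host(unsigned int *counters, int num_chunks, bool allgather) {
  for (int i = 0; i < num_chunks; i++) {
    counters[i] = 1;               // producer flags re-armed
    counters[i + num_chunks] = 0;  // consumer flags cleared
  }
  if (allgather) counters[0] = 0;  // all-gather additionally clears the first slot
}
```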
template <typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS / 4)
    reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
@@ -2546,3 +2564,24 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output,
template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale,
                                                    int num_inputs, int input_size,
                                                    cudaStream_t stream);
__global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) {
const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
half *inputs_half = reinterpret_cast<half *>(inputs);
float accum_buf = static_cast<float>(inputs_half[tid]);
#pragma unroll
for (int i = 1; i < num_inputs; i++) {
accum_buf += static_cast<float>(inputs_half[tid + input_size * i]);
}
half *output_half = reinterpret_cast<half *>(output);
output_half[tid] = (half)accum_buf;
}
void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) {
size_t num_threads = MAX_THREADS / 4;
size_t num_blocks = (input_size + num_threads - 1) / num_threads;
dim3 block(num_threads);
dim3 grid(num_blocks);
reduce_bf16_cuda<<<grid, block, 0, stream>>>(inputs, output, num_inputs, input_size);
}
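The kernel above sums `num_inputs` contiguous slices of `input_size` elements each, where element `t` of input `i` lives at offset `t + input_size * i`, and accumulates in float. A host-side sketch of the same indexing over plain floats (the `reduce_slices_host` name is hypothetical, only to make the stride pattern concrete):

```cpp
#include <cassert>
#include <vector>

// Host-side analogue of reduce_bf16_cuda's indexing: input i, element t
// is at inputs[t + input_size * i]; accumulation happens in float.
inline std::vector<float> reduce_slices_host(const std::vector<float> &inputs, int num_inputs,
                                             int input_size) {
  std::vector<float> out(input_size);
  for (int t = 0; t < input_size; t++) {
    float accum = inputs[t];  // i = 0
    for (int i = 1; i < num_inputs; i++) accum += inputs[t + input_size * i];
    out[t] = accum;
  }
  return out;
}
```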
@@ -19,11 +19,14 @@
#ifdef NVTE_UB_WITH_MPI
#include <mpi.h>
#define ExtComm MPI_Comm
#else
#define ExtComm const char *
#endif

using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
using ExtBarrierOp = std::function<void(ExtComm)>;

#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_SMS 32
#define NVTE_MAX_OPS 32
@@ -142,12 +145,12 @@ struct communicator {
  volatile int tail;

  // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
  ExtAllgatherOp _allgather;
  ExtBarrierOp _barrier;

  ExtComm comm_world;
  ExtComm comm_inter;  // reduction group communicator (subset of the nodes) along GPU rail
  ExtComm comm_intra;  // full intranode (all ndev GPUS)
#ifdef NVTE_UB_WITH_MPI
  MPI_Request mpihndl[NVTE_MAX_SHARP];
#endif
@@ -161,23 +164,22 @@ typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream);
void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream);

/* creates communicator, allocates all internal buffers if necessary */
int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
                                 int numlocal, int mynode, int numnodes,
                                 ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
                                 int pipegpus, int pipenodes, int tensorgpus, int tensornodes);
int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal,
                                int numlocal, int mynode, int numnodes,
                                ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
                                int pipegpus, int pipenodes);
int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
                        int mynode, int numnodes, ExtAllgatherOp ext_allgather,
                        ExtBarrierOp ext_barrier);
int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
                                     int tensorgpus, int tensornodes);
@@ -314,4 +316,6 @@ template <typename fp8type>
void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, int input_size,
                            cudaStream_t stream);

void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);

#endif  // TRANSFORMER_ENGINE_USERBUFFERS_H_
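When built without `NVTE_UB_WITH_MPI`, the caller supplies the bootstrap collectives through the `ExtAllgatherOp` and `ExtBarrierOp` `std::function` types declared in this header. A minimal sketch of what such callbacks look like for the degenerate single-rank case (the `make_single_rank_*` helper names and the local `Comm`/`AllgatherOp`/`BarrierOp` aliases are hypothetical, mirroring the non-MPI typedefs):

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <functional>

// Mirrors the non-MPI side of userbuffers.h: ExtComm is a plain string tag,
// and the callback signatures match ExtAllgatherOp / ExtBarrierOp.
using Comm = const char *;
using AllgatherOp = std::function<void(void *, size_t, void *, size_t, Comm)>;
using BarrierOp = std::function<void(Comm)>;

// Hypothetical single-process bootstrap: "allgather" over one rank is just a
// memcpy of the local buffer into the global one, and a one-rank barrier is a no-op.
inline AllgatherOp make_single_rank_allgather() {
  return [](void *global, size_t global_bytes, void *local, size_t local_bytes, Comm) {
    (void)global_bytes;
    std::memcpy(global, local, local_bytes);
  };
}
inline BarrierOp make_single_rank_barrier() {
  return [](Comm) {};  // nothing to wait for with a single rank
}
```

A real framework integration (e.g. over `torch.distributed`) would route these callbacks to actual collectives over the `comm_world`/`comm_intra`/`comm_inter` tags.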
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
#include <cuda.h>
#include <cuda_fp8.h>
#include <transformer_engine/transformer_engine.h>
#include <functional>
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
namespace transformer_engine {
/* \brief Check if Userbuffers bootstraps with direct calls to MPI collectives.
 *
 * This can be turned on by building Transformer Engine with the `NVTE_UB_WITH_MPI=1` option.
 *
 * \return True if Userbuffers is built with MPI.
 */
bool ubuf_built_with_mpi();
enum class CommOverlapType { RS = 0, AG = 1 };
enum class CommOverlapAlgo {
BULK_OVERLAP_AG = 0,
BULK_OVERLAP_RS = 1,
SPLIT_PIPELINED_AG_P2P = 2,
SPLIT_PIPELINED_RS = 3,
SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7
};
class CommOverlapCore {
protected:
static inline communicator *_ub_comm{nullptr};
static inline bool _comm_created{false};
int _rank;
int _tp_id;
int _tp_size;
int _num_splits;
int _math_sms;
int _num_comm_sm;
int _cga_size;
int _use_ce;
int _ub_reg;
bool _atomic_gemm{false};
bool _is_p2p{false};
TensorWrapper _ubuf;
TensorWrapper _counter;
float *_ubuf_scale_inv;
bool _ubuf_scale_inv_initialized{false};
std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
public:
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm,
bool set_sm_margin, bool use_ce, bool atomic_gemm);
virtual ~CommOverlapCore();
void set_ubuf_scale_inv(float *scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; }
bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }
}; // CommOverlapCore
class CommOverlapBase : public CommOverlapCore {
protected:
int _rs_kernel_type;
cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy;
public:
CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);
virtual ~CommOverlapBase();
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main);
/*
** Split FPROP GEMM + ReduceScatter
*/
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap,
TensorWrapper &rs_output, cudaStream_t stream_main);
/*
** Split FPROP GEMM + ReduceScatter
*/
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main);
}; // CommOverlapBase
class CommOverlapP2PBase : public CommOverlapCore {
protected:
bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false};
int _next_rank;
int _prev_rank;
int _rank_round_tp;
int _aggregate;
int _num_ubuf_chunks;
int _self_chunk_id;
std::vector<TensorWrapper> _ubufs;
cudaStream_t _stream_send;
cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv;
public:
CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS,
int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false,
bool use_ce = true, bool atomic_gemm = false, bool aggregate = false);
virtual ~CommOverlapP2PBase();
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
** outputs in each rank to be in the contiguous memory space after all ring exchange phases.
*/
void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main);
/*
** Split AllGather + GEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
** outputs in each rank to be in the contiguous memory space after all ring exchange phases.
*/
void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main);
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main);
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main);
}; // CommOverlapP2PBase
} // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
@@ -78,13 +78,13 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType
 */
void nvte_destroy_tensor(NVTETensor tensor);

/*! \brief Get a raw pointer to the tensor's data.
 *
 * \param[in] tensor Tensor.
 *
 * \return A raw pointer to tensor's data.
 */
void *nvte_tensor_data(const NVTETensor tensor);

/*! \brief Get a tensor's data shape.
 *
@@ -94,13 +94,46 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor);
 */
NVTEShape nvte_tensor_shape(const NVTETensor tensor);

/*! \brief Get a tensor's number of dimensions.
 *
 * \param[in] tensor Tensor.
 *
 * \return Number of tensor dimensions.
 */
size_t nvte_tensor_ndims(const NVTETensor tensor);
/*! \brief Get the size of a specific tensor dimension.
*
* \param[in] tensor Tensor.
 * \param[in] dim Dimension index.
*
* \return Size of the tensor at the specified dimension.
*/
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim);
/*! \brief Get a tensor's total number of elements.
*
* \param[in] tensor Tensor.
*
* \return Number of elements in the tensor.
*/
size_t nvte_tensor_numel(const NVTETensor tensor);
/*! \brief Get the byte size for the tensor's data type.
*
* \param[in] tensor Tensor.
*
* \return Byte size of the tensor's data type.
*/
size_t nvte_tensor_element_size(const NVTETensor tensor);
/*! \brief Get a tensor's data type.
*
* \param[in] tensor Tensor.
*
* \return A data type of the input tensor.
*/
NVTEDType nvte_tensor_type(const NVTETensor tensor);
/*! \brief Get a pointer to the tensor's amax data.
 *
@@ -265,6 +298,56 @@ class TensorWrapper {
    return nvte_tensor_shape(tensor_);
  }
/*! \brief Get the size of this TensorWrapper in the given dimension.
*
 * \param[in] dim Dimension index.
 *
 * \return Size of this TensorWrapper in the given dimension.
*/
size_t size(const size_t dim) const {
if (tensor_ == nullptr) return 0;
return nvte_tensor_size(tensor_, dim);
}
/*! \brief Get the number of dimensions for this TensorWrapper.
*
* \return Number of dimensions for this TensorWrapper.
*/
size_t ndim() const noexcept {
if (tensor_ == nullptr) return 0;
return nvte_tensor_ndims(tensor_);
}
/*! \brief Get the number of allocated elements in the tensor. This will return 0 for tensors
 * with nullptr data even if the TensorWrapper has a non-zero shape.
 *
* \return Number of elements in the tensor.
*/
size_t numel() const noexcept {
if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
return nvte_tensor_numel(tensor_);
}
/*! \brief Get the tensor's element size in bytes.
*
* \return Element size in bytes.
*/
size_t element_size() const noexcept {
if (tensor_ == nullptr) return 0;
return nvte_tensor_element_size(tensor_);
}
/*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr
* data even if the TensorWrapper has a non-zero shape and valid dtype.
*
* \return Total tensor size in bytes.
*/
size_t bytes() const noexcept {
if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_);
}
/*! \brief Get the data type of this TensorWrapper.
 *
 * \return Data type of this TensorWrapper.
@@ -317,6 +400,6 @@ class TensorWrapper {
}  // namespace transformer_engine

#endif  // __cplusplus

#endif  // TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_
@@ -93,6 +93,31 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  return ret;
}
size_t nvte_tensor_ndims(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  return t.data.shape.size();
}
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  // dim is unsigned, so only the upper bound needs checking.
  NVTE_CHECK(dim < t.data.shape.size(), "Invalid dimension index: ", dim);
  return t.data.shape[dim];
}
size_t nvte_tensor_numel(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
size_t numel = 1;
for (auto size : t.data.shape) {
numel *= size;
}
return numel;
}
size_t nvte_tensor_element_size(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return transformer_engine::typeToSize(t.data.dtype);
}
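Taken together, these helpers give the tensor's payload size as `numel * element_size`, with the element count being the product of the shape (an empty shape, i.e. a scalar, yields 1 by the empty-product convention). A standalone sketch of that arithmetic, independent of the `NVTETensor` type (helper names hypothetical):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Standalone analogue of nvte_tensor_numel / element-size arithmetic:
// element count is the product of the shape dimensions, and payload bytes
// multiply in the per-element size (e.g. 1 for FP8, 2 for FP16/BF16, 4 for FP32).
inline size_t shape_numel(const std::vector<size_t> &shape) {
  size_t numel = 1;
  for (size_t dim : shape) numel *= dim;
  return numel;
}
inline size_t shape_bytes(const std::vector<size_t> &shape, size_t element_size) {
  return shape_numel(shape) * element_size;
}
```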
void *nvte_tensor_data(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  return t.data.dptr;
......
@@ -12,6 +12,7 @@
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/system.h"
#include "common/util/cuda_runtime.h"

namespace transformer_engine {
@@ -80,6 +81,31 @@ int sm_count(int device_id) {
  return cache[device_id];
}
bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
// NOTE: This needs to be guarded at compile time because the
// CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
static std::vector<bool> cache(num_devices(), false);
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&]() {
CUdevice cudev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
int result;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
cache[device_id] = static_cast<bool>(result);
};
std::call_once(flags[device_id], init);
return cache[device_id];
#else
return false;
#endif
}
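The per-device query above is cached with one `std::once_flag` per device ID, so the driver call runs at most once per device no matter how often `supports_multicast()` is invoked. The same pattern in isolation, with the driver query replaced by a counter so the at-most-once behavior is observable (names and instrumentation are illustrative, not part of the original code):

```cpp
#include <cassert>
#include <mutex>
#include <vector>

// Instrumentation standing in for the cuDeviceGetAttribute call; not in the
// original code, only here to make the caching behavior visible.
inline int query_count = 0;

// One once_flag per device ID guarantees the (expensive) query runs at most
// once per device; later calls just read the cached value.
inline bool cached_attribute(int device_id, int num_devices) {
  static std::vector<bool> cache(num_devices, false);
  static std::vector<std::once_flag> flags(num_devices);
  std::call_once(flags[device_id], [&] {
    ++query_count;            // stands in for the driver attribute query
    cache[device_id] = true;  // stands in for the queried attribute value
  });
  return cache[device_id];
}
```

Note that `std::vector<std::once_flag>` works here because the flags are constructed in place and never copied or moved.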
const std::string &include_directory(bool required) {
  static std::string path;
......
@@ -38,6 +38,14 @@ int sm_arch(int device_id = -1);
 */
int sm_count(int device_id = -1);
/* \brief CUDA Multicast support status for device
*
* \param[in] device_id CUDA device (default is current device)
*
* \return CUDA multicast support flag
*/
bool supports_multicast(int device_id = -1);
/* \brief Path to CUDA Toolkit headers
 *
 * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_
#include <pybind11/pybind11.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>
#include "cuda_runtime.h"
#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \
pybind11::enum_<transformer_engine::DType>(m, "DType") \
.value("kByte", transformer_engine::DType::kByte) \
.value("kInt32", transformer_engine::DType::kInt32) \
.value("kFloat32", transformer_engine::DType::kFloat32) \
.value("kFloat16", transformer_engine::DType::kFloat16) \
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type") \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \
.value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \
pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type") \
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout") \
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \
.value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \
.value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \
.value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \
.value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \
.value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \
.value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend") \
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType") \
.value("RS", transformer_engine::CommOverlapType::RS) \
.value("AG", transformer_engine::CommOverlapType::AG); \
pybind11::enum_<transformer_engine::CommOverlapAlgo>(m, "CommOverlapAlgo") \
.value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \
.value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \
.value("SPLIT_PIPELINED_AG_P2P", \
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \
.value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \
.value("SPLIT_PIPELINED_RS_P2P", \
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \
m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \
py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \
m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \
py::call_guard<py::gil_scoped_release>());
#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_
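On the Python side, the names bound by this macro become ordinary enum members. The sketch below is purely illustrative: it mirrors a few `CommOverlapAlgo` members with a plain Python `Enum` and a toy selector (`pick_algo` is a hypothetical helper, not part of the TE API); in practice these enums come from the compiled extension module, not from this snippet:

```python
from enum import Enum


class CommOverlapAlgo(Enum):
    """Illustrative mirror of a subset of the values bound above."""
    BULK_OVERLAP_AG = 0
    BULK_OVERLAP_RS = 1
    SPLIT_PIPELINED_AG_P2P = 2
    SPLIT_PIPELINED_RS = 3
    SPLIT_PIPELINED_RS_P2P = 4


def pick_algo(comm_type: str, p2p: bool) -> CommOverlapAlgo:
    """Toy selector: choose a split-pipelined algorithm by comm type and transport."""
    if comm_type == "AG":
        return CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
    return (CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P if p2p
            else CommOverlapAlgo.SPLIT_PIPELINED_RS)


print(pick_algo("RS", p2p=False).name)  # SPLIT_PIPELINED_RS
```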