Unverified Commit 5ee98175 authored by Alp Dener, committed by GitHub

[PyTorch] Fixing hang in `initialize_ub()` for multi-node runs after PR901 removal of MPI-dependence (#986)

* Re-implementing PR901 (removing MPI-dependence in Userbuffers) with multi-node fixes

* passing data-parallel rank/size info from torch.distributed to userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* multi-node example working with UB_SKIPMC=1 but not with multicast
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed multi-node hang in initialize_ub(), updated comm+GEMM overlap example to support multi-node mixed tensor/data parallelism, added README
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed use case when Userbuffers is asked to allocate the TP overlap buffer with UB_SKIPMC=1
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected example problem to set device by local ordinal instead of global process rank
Signed-off-by: Alp Dener <adener@nvidia.com>

* double-free fix in userbuffers destructor
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unnecessary and incorrect torch.cuda.set_device(...)
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected inter-node ranks logic
Signed-off-by: Alp Dener <adener@nvidia.com>

* generalized node ID logic in initialize_ub to handle arbitrary world rank layouts within node
Signed-off-by: Alp Dener <adener@nvidia.com>

* added single-node comm+GEMM overlap unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

* LayerNormMLP example confirmed working with 2 nodes on Eos
Signed-off-by: Alp Dener <adener@nvidia.com>

* unit test cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected DP group ranks logic in LNMLP comm+GEMM overlap example
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected enums in unit test
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect Ubuf object init signature
Signed-off-by: Alp Dener <adener@nvidia.com>

* switched default backend for Userbuffers bootstrapping to Gloo with MPI and NCCL fallbacks, and added an `initialize_ub` option to manually select the backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed all comm+GEMM overlap unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected all_gather use for Gloo backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* changed userbuffers allgather callback to always use all_gather() instead of all_gather_into_tensor()
Signed-off-by: Alp Dener <adener@nvidia.com>

* restored and verified old MPI-based bootstrapping via NVTE_UB_WITH_MPI=1 option at compile time
Signed-off-by: Alp Dener <adener@nvidia.com>

* disabled scoped GIL release for comm+GEMM overlap algorithms
Signed-off-by: Alp Dener <adener@nvidia.com>

* avoid dist.init_device_mesh in comm+GEMM overlap example to support older PyTorch versions
Signed-off-by: Alp Dener <adener@nvidia.com>

* applied RS overlap FP8 fix from PR1004
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed segfault in Userbuffers destructor
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected comm+GEMM overlap unit test arguments
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed unit test run command for when Userbuffers is compiled with MPI
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Refactored torch.distributed collectives into pure C++ callbacks
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 71124c31
......@@ -77,14 +77,14 @@ def setup_pytorch_extension(
# Libraries
library_dirs = []
libraries = []
if os.getenv("UB_MPI_BOOTSTRAP"):
if os.getenv("NVTE_UB_WITH_MPI"):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
mpi_home = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_home / "include")
cxx_flags.append("-DUB_MPI_BOOTSTRAP")
nvcc_flags.append("-DUB_MPI_BOOTSTRAP")
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs.append(mpi_home / "lib")
libraries.append("mpi")
......
# Overlapping Communication with GEMM in TransformerEngine Modules
## Requirements
- Tensor-parallel GPUs must be on a single node, and connected over NVLink/NVSwitch.
- `CUDA_DEVICE_MAX_CONNECTIONS=1` must be set in the environment.
- For best performance, point-to-point communication via _CUDA Multicast_ needs CUDA Toolkit 12.0+
and CUDA driver 535+ on devices with compute capability 9.0 or newer.
- Devices older than compute capability 9.0 require `UB_SKIPMC=1` in the environment in order to
  fall back on a less performant implementation based on CUDA Inter-Process Communication (IPC)
  handles; see the sketch below.
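A minimal sketch of this fallback, using the runtime capability check added by this PR (assuming it
is exposed as `device_supports_multicast()` under `transformer_engine.pytorch.cpp_extensions`, as in
the unit tests further down in this diff):
```python
import os

import transformer_engine.pytorch.cpp_extensions as tex

# Fall back on CUDA IPC handles when the device cannot use CUDA Multicast.
# UB_SKIPMC must be set before Userbuffers is bootstrapped in initialize_ub().
if not tex.device_supports_multicast():
    os.environ["UB_SKIPMC"] = "1"
```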
## Examples
### Single node, tensor-parallel LayerNormMLP:
Forward and backward passes with layer weights distributed over all GPUs in a single node.
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py
# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
# !!! [UB] Create UbufP2PCommOverlap Communicator
# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
# MC initialized succesfully, window size = 549755813888
# !!! [UBP2P] Register UBuf 1
# !!! [UBP2P] Register UBuf 2
# !!! [UBP2P] Register UBuf 3
# !!! [UBP2P] Register UBuf 4
# !!! [UB] Register UBuf 5
# !!! [UBP2P] Register UBuf 6
# !!! [UB] Register UBuf 7
# !!! [UB] Register UBuf 8
# !!! [UBP2P] Register UBuf 9
# !!! [UB] Register UBuf 10
# [rank0:node0] Iter 1
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 2
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 3
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 4
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 5
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
```
### Single node, mixed data- and tensor-parallel LayerNormMLP:
Uses `torch.nn.parallel.DistributedDataParallel` for replicating the model across 2 tensor-parallel
groups in a single node.
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py --num-replicas 2
# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
# [rank4:node1] |-- Created tensor-parallel group: [4, 5, 6, 7]
# [rank0:node0] |-- Created data-parallel group: [0, 4]
# [rank3:node1] |-- Created data-parallel group: [3, 7]
# [rank1:node1] |-- Created data-parallel group: [1, 5]
# [rank2:node0] |-- Created data-parallel group: [2, 6]
# !!! [UB] Create UbufP2PCommOverlap Communicator
# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
# MC initialized succesfully, window size = 549755813888
# !!! [UBP2P] Register UBuf 1
# !!! [UBP2P] Register UBuf 2
# !!! [UBP2P] Register UBuf 3
# !!! [UBP2P] Register UBuf 4
# !!! [UB] Register UBuf 5
# !!! [UBP2P] Register UBuf 6
# !!! [UB] Register UBuf 7
# !!! [UB] Register UBuf 8
# !!! [UBP2P] Register UBuf 9
# !!! [UB] Register UBuf 10
# [rank4:node1] Iter 1
# [rank0:node0] Iter 1
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 2
# [rank0:node0] Iter 2
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 3
# [rank0:node0] Iter 3
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank4:node1] |-- Optimizer step
# [rank0:node0] Iter 4
# [rank4:node1] Iter 4
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 5
# [rank0:node0] Iter 5
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
```
**NOTE:** To run with FP8 compute on supported hardware, add the `--fp8` flag to the commands
shown above.
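For reference, the overlap setup inside the example script boils down to roughly the following
calls. This is a condensed, illustrative sketch: the sizes are stand-ins for the scripts' CLI
defaults, the `ub_cfgs` dictionary and full process-group setup are elided, and the exact
`initialize_ub()` signature should be checked against this PR's changes in
`transformer_engine/pytorch/module/base.py`.
```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Sizes matching the example's CLI defaults.
seq_length, batch_size, num_heads, head_dim = 2048, 2, 64, 128
hidden_size = num_heads * head_dim

# Assumes torch.distributed is already initialized (e.g. via torchrun) and all
# ranks on this node form a single tensor-parallel group.
tp_group = dist.new_group(backend="nccl")
tp_size = dist.get_world_size(tp_group)

# Bootstrap Userbuffers; `bootstrap_backend` is the new option added by this PR.
te.module.base.initialize_ub(
    [seq_length * batch_size, hidden_size],  # communication buffer shape
    tp_size,
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="gloo",  # or "mpi" / "nccl"
)

# Enable comm+GEMM overlap in the fused LayerNorm + MLP module.
model = te.LayerNormMLP(
    hidden_size,
    4 * hidden_size,
    tp_group=tp_group,
    set_parallel_mode=True,
    sequence_parallel=True,  # required for comm+GEMM overlap
    seq_length=seq_length,
    micro_batch_size=batch_size,
    ub_overlap_ag=True,
    ub_overlap_rs=True,
    ub_overlap_rs_dgrad=True,
    ub_bulk_wgrad=True,
)
```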
......@@ -6,17 +6,22 @@
import os
import sys
import subprocess
import socket
import argparse
import warnings
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
def parse_args(argv=None, namespace=None):
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
)
......@@ -47,63 +52,182 @@ def parse_args(argv=None, namespace=None):
default=False,
help="Disable the comm+GEMM overlap.",
)
parser.add_argument("-v", "--verbose", action="store_true", default=False)
return parser.parse_args(argv, namespace)
parser.add_argument(
"--num-replicas", type=int, default=1, help="Number of data-parallel model replicas."
)
parser.add_argument(
"--tcp-init",
action="store_true",
default=False,
help="Initialize torch.distributed with TcpStore.",
)
parser.add_argument(
"--bind-to-device",
action="store_true",
default=False,
help="Initialize torch.distributed with `device_id` to bind each rank to a single device.",
)
parser.add_argument(
"--bootstrap-backend",
type=str.lower,
default="nccl",
choices=["gloo", "mpi", "nccl"],
help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
default=False,
help="Print out from every rank instead of just the root rank of relevant process groups.",
)
parser.add_argument(
"--debug",
action="store_true",
default=False,
help="Print out additional debug information.",
)
args = parser.parse_args(argv, namespace)
return args
def _train(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N`
WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True
opts.bind_to_device = True
opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ:
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
NUM_NODES = WORLD_SIZE // LOCAL_SIZE
def dist_print(msg, group=None, end="\n", debug=False):
if debug and not opts.debug:
return
group = dist.new_group() if group is None else group
group_rank = dist.get_rank(group)
group_size = dist.get_world_size(group)
all_ranks = dist.get_process_group_ranks(group)
ranks_skip = all_ranks[1] - all_ranks[0] > 1
group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size
if group_rank == 0 or opts.verbose:
print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True)
dist.barrier(group)
# Initialize torch.distributed global process group and get DP/TP groups
torch.cuda.set_device(LOCAL_RANK)
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
if opts.tcp_init or NUM_NODES > 1:
if NUM_NODES > 1:
assert (
"MASTER_ADDR" in os.environ
), "Multi-node run requires MASTER_ADDR to be set in the environment."
MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
MASTER_PORT = os.getenv("MASTER_PORT", "1234")
dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
if opts.bind_to_device or opts.bootstrap_backend == "nccl":
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
dist_print(f"Initialized default NCCL process group with {WORLD_RANK} GPUs", nccl_world)
def train(opts):
WORLD_RANK = int(os.getenv("RANK"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE"))
# Figure out process groups for tensor- and data-parallelism (if any)
if NUM_NODES > 1:
# Create a list of world ranks on this node
hostnames = [None for _ in range(WORLD_SIZE)]
hostname = socket.gethostname()
dist.all_gather_object(hostnames, hostname)
node_ranks = []
for i, host in enumerate(hostnames):
if host == hostname:
node_ranks.append(i)
def dist_print(msg, end="\n", all_ranks=False):
if WORLD_RANK == 0 or all_ranks:
print(f"[RANK-{WORLD_RANK}] {msg}", end=end)
if opts.num_replicas > 1:
# Split node ranks into multiple replicas
assert len(node_ranks) % opts.num_replicas == 0
tp_size = len(node_ranks) // opts.num_replicas
found_replica = False
for replica in range(opts.num_replicas):
start = replica * tp_size
end = start + tp_size
tp_ranks = node_ranks[start:end]
if WORLD_RANK in tp_ranks:
found_replica = True
break
assert found_replica
else:
# The entire node is the tensor-parallel group
tp_ranks = node_ranks
# Seed RNG
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(opts.seed + WORLD_RANK)
torch.cuda.manual_seed(opts.seed + WORLD_RANK)
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
tp_size = dist.get_world_size(tp_group)
tp_rank = dist.get_rank(tp_group)
# Initialize torch.distributed global process group and get TP group
dist.init_process_group(
backend="nccl",
rank=WORLD_RANK,
world_size=WORLD_SIZE,
device_id=torch.device(f"cuda:{WORLD_RANK}"),
)
tp_group = dist.new_group(backend="nccl")
# Data-parallelism across TP groups
dp_start = tp_rank
dp_end = dp_start + WORLD_SIZE
dp_ranks = list(range(dp_start, dp_end, tp_size))
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
else:
if opts.num_replicas > 1:
# Mixed data- and tensor-parallelism on a single node
# NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas))
node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist()
tp_ranks = mesh2d[node_idx[0], :].tolist()
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
dp_ranks = mesh2d[:, node_idx[1]].tolist()
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
else:
dp_group = None
tp_group = nccl_world
tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
dist_print(
f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}",
group=tp_group,
)
if dp_group is not None:
dist_print(
f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}",
group=dp_group,
)
# Initialize Userbuffers
ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad
"method": "ring_exchange",
"num_splits": 8,
"num_sm": 1,
"set_sm_margin": False,
}
rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop
"method": "ring_exchange",
"num_splits": 4,
"num_sm": 1,
"set_sm_margin": True,
}
hidden_size = opts.num_heads * opts.head_dim
batched_size = opts.seq_length * opts.batch_size
if not opts.no_comm_overlap:
te.initialize_ub(
te.module.base.initialize_ub(
[batched_size, hidden_size],
tp_group,
tp_size,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
ub_cfgs={
"fc1_fprop": ag_cfg,
"fc1_dgrad": rs_cfg,
"fc2_fprop": rs_cfg,
"fc2_dgrad": ag_cfg,
},
bootstrap_backend=opts.bootstrap_backend,
)
#
# Initialize the fused LayerNorm + Multi-layer Perceptron module
torch.manual_seed(opts.seed + tp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank)
model = te.LayerNormMLP(
hidden_size,
opts.mlp_expansion_factor * hidden_size,
......@@ -114,11 +238,14 @@ def train(opts):
set_parallel_mode=True,
sequence_parallel=True, # this is required for comm+GEMM overlap
seq_length=opts.seq_length,
micro_batch_size=opts.batch_size,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_overlap_rs=not opts.no_comm_overlap,
ub_overlap_ag=not opts.no_comm_overlap,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_bulk_dgrad=False,
ub_bulk_wgrad=not opts.no_comm_overlap,
)
if dp_group is not None:
model = DistributedDataParallel(model, process_group=dp_group)
# Initialize optimizer with model parameters
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
......@@ -128,10 +255,11 @@ def train(opts):
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations
dist_print("Starting training iterations...", nccl_world)
for i in range(opts.num_iters):
dist_print(f"Iter {i+1}", all_ranks=opts.verbose)
dist_print(f" Iter {i+1}", tp_group, debug=True)
dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
dist_print(" |-- Generate random input batch", tp_group, debug=True)
x = torch.rand(
(opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16,
......@@ -139,30 +267,29 @@ def train(opts):
requires_grad=True,
)
dist_print("|-- Forward pass", all_ranks=opts.verbose)
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
dist_print(" |-- Forward pass", tp_group, debug=True)
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x)
dist_print("|-- Compute loss", all_ranks=opts.verbose)
dist_print(" |-- Compute loss", tp_group, debug=True)
loss = y.flatten().sum()
dist_print("|-- Backward pass", all_ranks=opts.verbose)
dist_print(" |-- Backward pass", tp_group, debug=True)
loss.backward()
dist_print("|-- Optimizer step", all_ranks=opts.verbose)
dist_print(" |-- Optimizer step", tp_group, debug=True)
optim.step()
te.destroy_ub()
torch.cuda.synchronize()
dist_print("Finished training!")
te.module.base.destroy_ub()
dist_print("Destroying all process groups...", debug=True)
dist.destroy_process_group()
if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True)
return 0
if __name__ == "__main__":
if "TORCHELASTIC_RUN_ID" in os.environ.keys():
args = parse_args()
train(args)
else:
subprocess.run(
["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
env=os.environ,
check=True,
)
os._exit(0)
sys.exit(_train(_parse_args()))
#!/usr/bin/python3
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import socket
import warnings
import subprocess
import argparse
import operator
from functools import partial, reduce
import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.fp8 import _default_sf_compute
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
torch_dtypes = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
nvte_comm_types = {
"rs": 0,
"ag": 1,
}
def _mapped_argtype(opt, typemap):
if str(opt).lower() not in typemap.keys():
raise TypeError(f"Unrecognized option! Please choose from: {typemap.keys()}")
return typemap[str(opt).lower()]
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.")
parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
parser.add_argument(
"-n", "--num-heads", type=int, default=64, help="Number of attention heads."
)
parser.add_argument(
"-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
)
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
)
parser.add_argument(
"--p2p", action="store_true", default=False, help="Test overlap with P2P comms."
)
parser.add_argument(
"--atomic", action="store_true", default=False, help="Test overlap with atomic GEMM."
)
parser.add_argument(
"--aggregate",
action="store_true",
default=False,
help="Aggregate 2X chunks for P2P split pipelined all-gather.",
)
parser.add_argument(
"--comm-type",
type=partial(_mapped_argtype, typemap=nvte_comm_types),
default=0,
help="Comm type to overlap.",
)
parser.add_argument(
"--bulk-overlap",
action="store_true",
default=False,
help="Enable bulk AG or RS overlap for a tensor that is not involved in the GEMM compute.",
)
parser.add_argument(
"--check-numerics",
action="store_true",
default=False,
help="Test numerical result against torch.matmul(...)",
)
parser.add_argument(
"--warmup-iters",
type=int,
default=0,
help="Run some warmup iterations of the comm+GEMM overlap before " + "the timing runs.",
)
parser.add_argument(
"--timing-iters",
type=int,
default=1,
help="Benchmark the comm+GEMM overlap as an average of many iterations.",
)
parser.add_argument(
"--clock-speed",
type=int,
default=-1,
help="Set device clock speed to a fixed value via `nvidia-smi`.",
)
parser.add_argument(
"--scale", type=float, default=1e-2, help="Set scaling factor for input and weight tensors."
)
parser.add_argument(
"--tcp-init",
action="store_true",
default=False,
help="Initialize torch.distributed with TcpStore.",
)
parser.add_argument(
"--init-method", type=str, default=None, help="Set the torch.distributed init method."
)
parser.add_argument(
"--bind-to-device",
action="store_true",
default=False,
help=(
"Initialize torch.distributed with 'device_id' argument to bind each rank to 1 device."
),
)
parser.add_argument(
"--bootstrap-backend",
type=str.lower,
default="nccl",
choices=["gloo", "mpi", "nccl"],
help=(
"PyTorch distributed backend for host tensor collectives during comm+GEMM overlap "
+ "initialization."
),
)
parser.add_argument(
"-v", "--verbose", action="store_true", default=False, help="Verbose info messages."
)
opts = parser.parse_args(argv, namespace)
if opts.bulk_overlap:
if opts.p2p:
warnings.warn("Point-2-point comms are not supported with bulk overlap.")
opts.p2p = False
if opts.atomic:
warnings.warn("Atomic GEMM is not supported with bulk overlap.")
opts.atomic = False
if opts.fp8:
warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
opts.fp8 = False
elif opts.comm_type == 1 and not opts.p2p:
warnings.warn("All-gather overlap is only supported with point-2-point comms.")
opts.p2p = True
if opts.atomic:
if not te.fp8.check_fp8_support():
assert not opts.fp8, "Atomic GEMM is only supported in FP8."
opts.fp8 = True
return opts
@record
def _main(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N`
WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True
opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ:
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
# Fix clock speed
torch.cuda.set_device(LOCAL_RANK)
if opts.clock_speed > 0:
subprocess.run(
["nvidia-smi", "-pm", "ENABLED", "-i", str(LOCAL_RANK)],
env=os.environ,
check=False,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
result = subprocess.run(
["nvidia-smi", "-lgc", str(opts.clock_speed), "-i", str(LOCAL_RANK)],
env=os.environ,
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
msg = result.stdout.decode("utf-8").splitlines()[0]
print(f"[rank:{LOCAL_RANK}] {msg}\n", end="", flush=True)
# Info printout
def dist_print(msg, src=None, info=False, section=False, group=None):
group = dist.new_group() if group is None else group
rank = dist.get_rank(group)
if info or opts.verbose:
if section:
if rank == (0 if src is None else src):
print("\n", end="", flush=True)
dist.barrier(group)
if src is None or rank == src:
prefix = "[GLOBAL] " if src is not None else f"[rank:{rank}] "
lines = msg.splitlines()
msg = "\n".join(
[prefix + lines[0]] + [(" " * len(prefix)) + line for line in lines[1:]]
)
print(msg + "\n", end="", flush=True)
dist.barrier(group)
# Initialize torch.distributed global process group and get TP group
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
if opts.tcp_init:
if opts.init_method is not None:
assert opts.init_method.startswith("tcp://")
init_method = opts.init_method
else:
MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
MASTER_PORT = os.getenv("MASTER_PORT", "1234")
init_method = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
dist_init_kwargs["init_method"] = init_method
elif opts.init_method is not None:
assert (
opts.init_method.startswith("env://")
or opts.init_method.startswith("file://")
or opts.init_method.startswith("tcp://")
)
dist_init_kwargs["init_method"] = opts.init_method
if opts.bind_to_device or opts.bootstrap_backend == "nccl":
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
tp_group = dist.new_group(backend="nccl")
tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
dist_print(
f"Initialized default NCCL process group with {tp_size} GPUs",
src=0,
section=True,
info=True,
group=tp_group,
)
# Initialize backend used in bootstrapping Userbuffers
if opts.bootstrap_backend == "gloo":
assert dist.is_gloo_available()
elif opts.bootstrap_backend == "mpi":
assert dist.is_mpi_available()
bootstrap_pg = dist.new_group(backend=opts.bootstrap_backend)
dist_print(
f'Bootstrapping comm+GEMM overlap with backend="{opts.bootstrap_backend}"',
src=0,
section=True,
info=True,
group=bootstrap_pg,
)
if WORLD_RANK == 0:
print("\n", end="", flush=True)
ub_callbacks = (
tex.UbufBootstrapCallbacks()
if tex.ubuf_built_with_mpi()
else tex.UbufBootstrapCallbacks(bootstrap_pg, bootstrap_pg)
)
if opts.comm_type == 0:
if opts.bulk_overlap:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_RS
elif opts.p2p:
ub_algo = (
tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
if opts.atomic
else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
)
else:
ub_algo = (
tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
if opts.atomic
else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
)
elif opts.comm_type == 1:
if opts.bulk_overlap:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
else:
ub_algo = (
tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
if opts.atomic
else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
)
else:
raise TypeError("Invalid comm+GEMM overlap type!")
# Initialize userbuffers with (M, N) buffer
# M = sequence * batch
# N = hidden size
hidden_size = opts.num_heads * opts.head_dim
inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
outer_size = reduce(operator.mul, inp_shape[:-1], 1)
ubuf_dtype = torch.uint8 if opts.fp8 and opts.comm_type == 1 else torch.bfloat16
sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda")
ub_obj = (
tex.UbufP2PCommOverlap(
sample_buffer, # Sample userbuffer
WORLD_RANK, # World rank
WORLD_SIZE, # World size
LOCAL_RANK, # Rank within the node
LOCAL_SIZE, # Number of ranks/GPUs per node
0, # Node ID
1, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
1, # Number of communication SMs
1, # CGA cluster size
opts.comm_type == 0 or opts.atomic, # Set SM margin
opts.aggregate, # Aggregate 2X GEMM chunks
3, # Max concurrent GEMM streams
opts.comm_type == 0, # overlap with reduce scatter
opts.atomic, # use a single GEMM with atomic-counters
True, # Use copy engine for P2P communications
ub_callbacks,
)
if opts.p2p
else tex.UbufCommOverlap(
sample_buffer, # Sample userbuffer
WORLD_RANK, # World rank
WORLD_SIZE, # World size
LOCAL_RANK, # Rank within the node
LOCAL_SIZE, # Number of ranks/GPUs per node
0, # Node ID
1, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
16, # Number of communication SMs
2, # CGA cluster size
4, # Number of communication splits
True, # Set SM margin
3, # Max concurrent GEMM streams
opts.atomic, # Use a single GEMM with atomic-counters
ub_callbacks,
)
)
# Numerical check on AG + atomic GEMM requires testing an AG+RS pair
ub_obj2 = None
if opts.atomic and opts.comm_type == 1 and opts.check_numerics:
sample_buffer2 = torch.empty((outer_size, hidden_size), dtype=torch.bfloat16, device="cuda")
ub_obj2 = tex.UbufP2PCommOverlap(
sample_buffer2, # Sample userbuffer
WORLD_RANK, # World rank
WORLD_SIZE, # World size
LOCAL_RANK, # Rank within the node
LOCAL_SIZE, # Number of ranks/GPUs per node
0, # Node ID
1, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
1, # Number of communication SMs
1, # CGA cluster size
True, # Set SM margin
False, # Aggregate 2X GEMM chunks
3, # Max concurrent GEMM streams
True, # overlap with reduce scatter
True, # use a single GEMM with atomic-counters
True, # use copy engine for P2P communications
ub_callbacks,
)
# Figure out problem sizing:
# M = sequence * batch
# N = hidden size
# K = MLP intermediate size (usually 4x hidden size)
# P = number of devices for sequence/tensor parallelism
# NOTE: TE-GEMM is set up to work with transposed kernels and non-transposed inputs.
ffn_hidden_size = 4 * hidden_size
if opts.bulk_overlap:
# Bulk overlap weight and input tensors are not relevant so they're globally sized
local_kernel_t_shape = (ffn_hidden_size, hidden_size)
local_inp_shape = (outer_size, hidden_size)
# Bulk overlap comm tensor is distributed for AG overlap only
if opts.comm_type == 1:
bulk_inp_shape = (outer_size // tp_size, hidden_size)
else:
bulk_inp_shape = (outer_size, hidden_size)
else:
if opts.comm_type == 1:
# (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P)
local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size)
local_inp_shape = (outer_size // tp_size, hidden_size)
if ub_obj2 is not None:
local_kernel2_t_shape = (hidden_size, ffn_hidden_size // tp_size)
else:
# (M, K/P) x (N, K/P)^T = (M, N) -> overlapped RS -> (M/P, N)
local_kernel_t_shape = (hidden_size, ffn_hidden_size // tp_size)
local_inp_shape = (outer_size, ffn_hidden_size // tp_size)
# Initialize distributed input tensor and GEMM kernels
torch.manual_seed(opts.seed + tp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank)
inp = torch.mul(torch.rand(local_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale)
kernel_t = torch.mul(
torch.rand(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale
)
if ub_obj2 is not None:
kernel2_t = torch.mul(
torch.rand(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale
)
# Gather global tensors and calculate reference result (need these first for Fp8 scales)
if opts.bulk_overlap:
ker_g = torch.transpose(kernel_t, 0, 1)
inp_g = inp
bulk_inp = torch.mul(
torch.rand(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale
)
else:
if opts.comm_type == 1:
# AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K)
ker_g = torch.transpose(
te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1
)
# AG Input: (M/P, N) -> gather -> (M, N)
inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0]
if ub_obj2 is not None:
ker2_g = te.distributed.gather_along_first_dim(
torch.transpose(kernel2_t, 0, 1), tp_group
)[0]
else:
# RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N)
ker_g = te.distributed.gather_along_first_dim(
torch.transpose(kernel_t, 0, 1), tp_group
)[0]
# RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K)
inp_g = torch.transpose(
te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1
)
if opts.bulk_overlap:
if opts.comm_type == 1:
ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0]
else:
# First all-gather all the bulk inputs into a list
bulk_inp_list = [torch.zeros_like(bulk_inp) for _ in range(tp_size)]
dist.all_gather(bulk_inp_list, bulk_inp, tp_group)
# Sum the list together for final global result
ref_g = torch.stack(bulk_inp_list).sum(dim=0)
else:
ref_g = torch.matmul(inp_g, ker_g)
if ub_obj2 is not None:
inp2_g = torch.mul(ref_g, opts.scale)
ref2_g = torch.matmul(inp2_g, ker2_g)
if opts.fp8:
fp8_formats = {
tex.DType.kFloat8E4M3: Format.E4M3,
tex.DType.kFloat8E5M2: Format.E5M2,
}
# Structure to maintain amax and scale/scale_inv information for the kernel and input
fp8_dtype = tex.DType.kFloat8E4M3
fp8_meta = tex.FP8TensorMeta()
num_gemms = 6 if ub_obj2 is not None else 3
fp8_meta.amax_history = torch.zeros((2, num_gemms), dtype=torch.float, device="cuda")
fp8_meta.scale = torch.ones(num_gemms, dtype=torch.float, device="cuda")
fp8_meta.scale_inv = torch.ones(num_gemms, dtype=torch.float, device="cuda")
# Compute initial amaxes and scales
inp_amax = torch.max(torch.abs(inp_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_INPUT].copy_(inp_amax)
ker_amax = torch.max(torch.abs(ker_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax)
ref_amax = torch.max(torch.abs(ref_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax)
if ub_obj2 is not None:
inp2_amax = torch.max(torch.abs(inp2_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax)
ker2_amax = torch.max(torch.abs(ker2_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_WEIGHT].copy_(ker2_amax)
ref2_amax = torch.max(torch.abs(ref2_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(ref2_amax)
fp8_meta.scale = _default_sf_compute(
fp8_meta.amax_history[1], fp8_meta.scale, fp8_formats[fp8_dtype].value.max_fwd, 1
)
fp8_meta.scale_inv = torch.reciprocal(fp8_meta.scale)
# Cast input to Float8Tensor
inp_fp8 = tex.cast_to_fp8(inp, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype)
# Cast kernel to Float8Tensor
kernel_t_fp8 = tex.cast_to_fp8(
kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype
)
if ub_obj2 is not None:
kernel2_t_fp8 = tex.cast_to_fp8(
kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype
)
# Make sure the inputs are cast correctly
if opts.check_numerics:
torch.allclose(
inp.to(dtype=torch.float32),
inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT],
rtol=0.125,
atol=0.0675,
)
torch.allclose(
kernel_t.to(dtype=torch.float32),
kernel_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT],
rtol=0.125,
atol=0.0675,
)
if ub_obj2 is not None:
torch.allclose(
kernel2_t.to(dtype=torch.float32),
kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT],
rtol=0.125,
atol=0.0675,
)
# Set Fp8 scales for userbuffers
if opts.comm_type == 1:
ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT])
if ub_obj2 is not None:
ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT])
else:
ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT])
# Set up comm/compute buffers
ubuf_out2 = None
rs_out2 = None
if opts.comm_type == 1:
if opts.bulk_overlap:
ub_obj.copy_input_to_ubuf(bulk_inp, 1)
gemm_inp = inp
else:
ub_obj.copy_input_to_ubuf(inp_fp8 if opts.fp8 else inp, 1)
gemm_inp = ub_obj.get_ubuf_output(1)
ubuf_out = None
rs_out = None
if ub_obj2 is not None:
ubuf_out2 = ub_obj2.get_ubuf_output(1)
rs_out2 = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
else:
if opts.bulk_overlap:
ub_obj.copy_input_to_ubuf(bulk_inp, 0)
ubuf_out = None
else:
ubuf_out = ub_obj.get_ubuf_output(1)
gemm_inp = inp_fp8 if opts.fp8 else inp
rs_out = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
# Trigger GEMM
total_iters = opts.warmup_iters + opts.timing_iters
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)]
torch.cuda.synchronize()
if opts.fp8:
for i in range(total_iters):
start_events[i].record()
all_outputs = tex.fp8_gemm(
kernel_t_fp8,
fp8_meta.scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype,
gemm_inp,
fp8_meta.scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype,
torch.bfloat16,
te.module.base.get_workspace(),
bias=None,
use_bias=False,
gelu=False,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub_algo=ub_algo,
ub=ub_obj,
extra_output_tensor=rs_out,
out=ubuf_out,
)
end_events[i].record()
if ub_obj2 is not None:
gemm2_inp = tex.cast_to_fp8(
torch.mul(all_outputs[0], opts.scale),
fp8_meta,
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype,
)
all_outputs = tex.fp8_gemm(
kernel2_t_fp8,
fp8_meta.scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype,
gemm2_inp,
fp8_meta.scale_inv,
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype,
torch.bfloat16,
te.module.base.get_workspace(),
bias=None,
use_bias=False,
gelu=False,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P,
ub=ub_obj2,
extra_output_tensor=rs_out2,
out=ubuf_out2,
)
else:
for i in range(total_iters):
start_events[i].record()
all_outputs = tex.gemm(
kernel_t,
gemm_inp,
torch.bfloat16,
te.module.base.get_workspace(),
bias=None,
use_bias=False,
gelu=False,
ub_algo=ub_algo,
ub=ub_obj,
extra_output_tensor=rs_out,
out=ubuf_out,
)
end_events[i].record()
torch.cuda.synchronize()
gpu_times = [
s.elapsed_time(e)
for s, e in zip(start_events[opts.warmup_iters :], end_events[opts.warmup_iters :])
]
avg_gpu_time = sum(gpu_times) / opts.timing_iters
gemm_name = "".join(
[
"p2p all-gather + " if opts.comm_type == 1 else "",
"atomic " if opts.atomic else "",
"GEMM",
(f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == 0 else ""),
]
)
timing_info = (
f"Avg. GPU time for {gemm_name}: {avg_gpu_time} ms "
+ f"({opts.warmup_iters} warmup + {opts.timing_iters} timing runs)"
)
dist_print(timing_info, section=True, info=True, group=tp_group)
# Compare against standard GEMM
numerics_failed = False
if opts.check_numerics:
torch.cuda.synchronize()
dist.barrier(tp_group)
if opts.bulk_overlap:
output_info = ""
if opts.comm_type == 1:
# Bulk overlap AG output is already gathered
test_out = ub_obj.get_ubuf_output(1)
else:
# Bulk overlap RS output needs to be gathered
out_local = ub_obj.get_ubuf_output(0)
output_info += f"rs_output: {list(out_local.shape)} | "
test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
ref_out = ref_g
output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}"
dist_print(output_info, src=0 if opts.comm_type == 0 else None, section=True)
test_nonzeros = torch.count_nonzero(test_out)
ref_nonzeros = torch.count_nonzero(ref_out)
nonzero_info = (
f"output nonzeros = {test_nonzeros} " + f"| reference count = {ref_nonzeros}"
)
dist_print(nonzero_info, src=0, section=True, group=tp_group)
else:
if opts.comm_type == 1:
if ub_obj2 is not None:
# AG+RS Output: (M/P, N) -> gather -> (M, N)
output = rs_out2
test_out = te.distributed.gather_along_first_dim(output, tp_group)[0]
else:
# AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K)
output = all_outputs[0]
test_out = torch.transpose(
te.distributed.gather_along_first_dim(
torch.transpose(output, 0, 1), tp_group
)[0],
0,
1,
)
else:
# RS Output: (M/P, N) -> gather -> (M, N)
output = rs_out
test_out = te.distributed.gather_along_first_dim(output, tp_group)[0]
if opts.fp8:
dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True)
fp8_meta_info = (
f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n"
+ f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n"
+ f"scale = {fp8_meta.scale[:3].tolist()}\n"
+ f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}"
)
dist_print(fp8_meta_info, src=0, group=tp_group)
if ub_obj2 is not None:
dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True)
fp8_meta_info = (
f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n"
+ f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n"
+ f"scale = {fp8_meta.scale[3:].tolist()}\n"
+ f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}"
)
dist_print(fp8_meta_info, src=0, group=tp_group)
ref_out = ref2_g if ub_obj2 is not None else ref_g
test_nonzeros = torch.count_nonzero(test_out)
ref_nonzeros = torch.count_nonzero(ref_out)
nonzero_info = (
f"output nonzeros = {test_nonzeros} " + f"| reference count = {ref_nonzeros}"
)
dist_print(nonzero_info, src=0, section=True, group=tp_group)
sizing_info = (
f"input: {list(inp.shape)} " + f"| GEMM1 weights: {list(kernel_t.shape)[::-1]} "
)
if ub_obj2 is not None:
sizing_info += f"| GEMM2 weights: {list(kernel2_t.shape)[::-1]} "
sizing_info += f"| output: {list(output.shape)}\n"
dist_print(sizing_info, section=True, group=tp_group)
sizing_info_g = (
f"input: {list(inp_g.shape)} " + f"| GEMM1 weights: {list(ker_g.shape)} "
)
if ub_obj2 is not None:
sizing_info_g += f"| GEMM2 weights: {list(ker2_g.shape)} "
sizing_info_g += (
f"| output: {list(test_out.shape)} " + f"| reference: {list(ref_out.shape)}\n"
)
dist_print(sizing_info_g, src=0, group=tp_group)
torch.cuda.synchronize()
dist.barrier(tp_group)
test_out = test_out.to(dtype=torch.float32)
ref_out = ref_out.to(dtype=torch.float32)
error_below_tol = torch.allclose(
test_out,
ref_out,
rtol=0.125 if opts.fp8 else 0.02,
atol=0.0675 if opts.fp8 else 0.001,
)
diff = torch.abs(test_out - ref_out).flatten()
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / (ref_out.flatten()[m].item() + 1e-5)
if not error_below_tol:
numerics_failed = True
numerics_info = (
"NUMERICAL CHECK FAILED: "
+ f"Outputs not close enough at index {m.item()} "
+ f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} "
+ f"(abs error = {abs_err} | rel error = {rel_err})."
)
else:
numerics_info = f"NUMERICAL CHECK PASSED: abs error = {abs_err} | rel error = {rel_err}"
dist_print(numerics_info, src=0, section=True, info=True, group=tp_group)
dist.barrier(tp_group)
if LOCAL_RANK == 0:
print("\n", end="", flush=True)
dist.destroy_process_group()
# Reset clock speeds
if opts.clock_speed > 0:
subprocess.run(
["nvidia-smi", "-pm", "ENABLED", "-i", str(LOCAL_RANK)],
env=os.environ,
check=False,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
result = subprocess.run(
["nvidia-smi", "-rgc", "-i", str(LOCAL_RANK)],
env=os.environ,
check=False,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return int(numerics_failed)
if __name__ == "__main__":
sys.exit(_main(_parse_args()))
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
RNG_SEED: int = 1234
SEQ_LENGTH: int = 2024
BATCH_SIZE: int = 2
NUM_HEADS: int = 64
HEAD_DIM: int = 128
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(torch.cuda.device_count(), 4)
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
if tex.ubuf_built_with_mpi():
LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python"]
# Fall back on CUDA IPC if the platform does not support CUDA multicast
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"
# Force GPU kernels to launch in the order they're executed by the host CPU
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.")
@pytest.mark.parametrize(
"fp8,p2p,comm_type,aggregate,atomic,bulk",
[
# FP8, P2P, Type, Aggregate, Atomic, Bulk
(False, True, "AG", False, False, False),
(False, True, "AG", True, False, False),
(True, True, "AG", False, False, False),
(True, True, "AG", True, False, False),
(False, False, "RS", False, False, False),
(False, True, "RS", False, False, False),
(True, False, "RS", False, False, False),
(True, True, "RS", False, False, False),
(True, False, "RS", False, True, False),
(True, True, "RS", False, True, False),
(False, False, "AG", False, False, True),
(False, False, "RS", False, False, True),
],
ids=[
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ",
" SPLIT GEMM -> RS | BF16 | PIPELINE ",
" SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ",
" SPLIT GEMM -> RS | FP8 | PIPELINE ",
" SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ",
" ATOMIC GEMM -> RS | FP8 | PIPELINE ",
" ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ",
" BULK AG & GEMM | BF16 | PIPELINE ",
" BULK RS & GEMM | BF16 | PIPELINE ",
],
)
def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
"""
Test comm+GEMM overlap algorithms with direct calls to
te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm
"""
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = (
LAUNCH_CMD
+ [str(test_path)]
+ [
"--check-numerics",
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--comm-type={comm_type}",
]
)
if bulk:
test_cmd.append("--bulk-overlap")
else:
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if p2p:
test_cmd.append("--p2p")
if aggregate:
test_cmd.append("--aggregate")
if atomic:
if torch.cuda.get_device_properties(0).major < 9:
pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.")
test_cmd.append("--atomic")
output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False)
assert "NUMERICAL CHECK PASSED" in str(output)
......@@ -19,7 +19,10 @@
#include <torch/extension.h>
#include <torch/types.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "extensions.h"
......@@ -28,75 +31,96 @@
#define HALF_BYTES 2
#define UB_MAX_SM 32
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
using namespace torch::indexing;
using namespace std::placeholders;
namespace ubuf {
/*
** Static container for Python callbacks to torch.distributed collectives
*/
static struct TorchCallbacks : torch::CustomClassHolder {
bool initialized{false};
std::unordered_map<void *, at::Tensor> gathered_tensors;
std::function<at::Tensor(at::Tensor &, const std::string &)> allgather;
std::function<void(const std::string &)> barrier;
std::function<void(at::Tensor &)> free;
} torch_callbacks;
/*
** Helper function for setting Python callbacks to torch.distributed collectives.
*/
void set_ubuf_bootstrap_callbacks(
std::function<at::Tensor(at::Tensor &, const std::string &)> allgather,
std::function<void(const std::string &)> barrier, std::function<void(at::Tensor &)> free) {
torch_callbacks.allgather = allgather;
torch_callbacks.barrier = barrier;
torch_callbacks.free = free;
torch_callbacks.initialized = true;
bool device_supports_multicast() {
int dev, supports_multicast;
CUdevice cudev;
NVTE_CHECK_CUDA(cudaGetDevice(&dev));
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, dev);
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &supports_multicast,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
return static_cast<bool>(supports_multicast);
}
bool ubuf_built_with_mpi() {
#ifdef NVTE_UB_WITH_MPI
return true;
#else
return false;
#endif
}
/*
** Python callback for globaldata = torch.distributed.all_gather(localdata, tp_group).
** This *creates* a new tensor, which Userbuffers later frees with a separate callback.
*/
void ub_alloc_copy_allgather(void **globaldata, void *localdata, size_t localbytes, char *group) {
assert(torch_callbacks.initialized);
class UbufBootstrapCallbacks : torch::CustomClassHolder {
private:
bool initialized{false};
bool backend_is_nccl{false};
std::map<std::string, c10d::ProcessGroup *> pgs;
public:
UbufBootstrapCallbacks() {
#ifndef NVTE_UB_WITH_MPI
NVTE_ERROR("Internal TE error: Dummy UbufBootstrapCallbacks init without NVTE_UB_WITH_MPI=1!");
#endif
}; // empty constructor for NVTE_UB_WITH_MPI=1
UbufBootstrapCallbacks(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group) {
pgs.insert({"world", world_group});
c10d::ProcessGroup::BackendType backend = world_group->getBackendType();
backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL);
NVTE_CHECK(intra_node_group->getBackendType() == backend,
"Internal TE error: Intra-node group must be on the same backend (%s) as the world ",
"group!", world_group->getBackendName());
pgs.insert({"intra", intra_node_group});
initialized = true;
}
~UbufBootstrapCallbacks() {
for (auto &pg : pgs) pg.second = nullptr;
backend_is_nccl = false;
initialized = false;
}
void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
char *group) {
NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ",
"with valid process groups!");
auto localtensor =
torch::from_blob(localdata, {static_cast<int64_t>(localbytes / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto globaltensor = torch_callbacks.allgather(localtensor, group);
*globaldata = globaltensor.data_ptr();
torch_callbacks.gathered_tensors[*globaldata] = globaltensor;
}
auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor;
auto globaltensor =
torch::from_blob(globaldata, {static_cast<int64_t>(globalbytes / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor;
/*
** Python callback for torch.distributed.barrier(tp_group).
*/
void ub_barrier(char *group) {
assert(torch_callbacks.initialized);
torch_callbacks.barrier(group);
}
std::vector<std::vector<torch::Tensor>> globalchunks = {globaltmp.chunk(pgs[group]->getSize())};
std::vector<torch::Tensor> localchunk = {localtmp};
auto work = pgs[group]->allgather(globalchunks, localchunk);
work->wait();
/*
** Python callback for freeing up tensors created in the ub_alloc_copy_allgather(...) callback.
*/
void ub_free(void *ptr) {
assert(torch_callbacks.initialized);
auto i = torch_callbacks.gathered_tensors.find(ptr);
if (i == torch_callbacks.gathered_tensors.end()) return;
auto tensor = std::move(i->second);
torch_callbacks.gathered_tensors.erase(i);
torch_callbacks.free(tensor);
}
if (backend_is_nccl) {
globaltensor.copy_(globaltmp.cpu());
globaltmp = torch::Tensor();
localtmp = torch::Tensor();
}
}
void ub_barrier(char *group) {
NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ",
"with valid process groups!");
auto work = pgs[group]->barrier();
work->wait();
}
};
enum class COMM_TYPE { RS = 0, AG = 1 };
......@@ -127,7 +151,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
torch::Tensor _ubuf_scale_inv;
bool _ubuf_scale_inv_initialized;
torch::Tensor counter;
torch::Tensor _empty_tensor;
at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm;
......@@ -136,36 +159,45 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int _use_ce;
bool _atomic_gemm;
UbufCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size,
int num_comm_sm, int comm_cga_size, int num_splits, bool set_sm_margin,
int num_max_streams, bool atomic_gemm, torch::Tensor empty_tensor) {
UbufCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size,
int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm,
UbufBootstrapCallbacks &callbacks) {
// Initialize userbuf communicator
if (!comm_created) {
if (rank == 0) {
if (myrank == 0) {
printf("!!! [UB] Create UbufCommOverlap Communicator\n");
}
if (transformer_engine::getenv<bool>("UB_MPI_BOOTSTRAP")) {
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
} else {
create_communicator_grouped2(&_ub_comm, rank, world_size, tp_rank, tp_size, 1, 1,
&ub_alloc_copy_allgather, &ub_barrier, &ub_free, 1, 1, tp_size,
1);
}
#else
create_communicator_grouped2(
&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5),
std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1);
#endif
comm_created = true;
}
_use_ce = 0;
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
_empty_tensor = empty_tensor;
// Allocate and register extra userbuffers
int ubuf_bytes = sample.numel() * sample.element_size();
if (transformer_engine::getenv<bool>("UB_SKIPMC")) {
_ubuf = torch::zeros_like(sample);
_ubuf_ptr = _ubuf.data_ptr();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, false);
} else {
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
if (rank == 0) {
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
}
if (_ub_comm->myrank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg);
}
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
......@@ -177,7 +209,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
_num_splits = num_splits;
_tp_size = tp_size;
_tp_id = (rank % tp_size);
_tp_id = (_ub_comm->myrank % _tp_size);
_ubuf_scale_inv_initialized = false;
// Set the number of SMs for GEMM with margin
......@@ -201,6 +233,25 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
cudaEventCreateWithFlags(&_stop_comm, 0);
}
~UbufCommOverlap() {
cudaEventDestroy(_stop_comm);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_start_d2dcopy);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
if (comm_created) {
#ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm);
#else
destroy_communicator(_ub_comm);
#endif
comm_created = false;
}
}
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
......@@ -226,8 +277,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication: AG and RS
if (_comm_type == COMM_TYPE::AG) {
......@@ -261,8 +312,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, workspaceSize,
accumulate, use_split_accumulator, _math_sms);
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
// Generate output tensor from userbuf data pointer
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
......@@ -305,9 +356,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
......@@ -326,6 +377,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/,
counter);
for (int i = 0; i < _num_splits; i++) {
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
......@@ -373,10 +425,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
_ub_comm->sms = ori_sms;
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return;
......@@ -416,11 +468,11 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
......@@ -456,9 +508,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(
NVTE_CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
if (_ubuf.element_size() == 1) {
......@@ -479,9 +531,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
......@@ -513,9 +565,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(_start_comm,
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm,
(cudaStream_t)_stream_compute[i % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits - 1) {
......@@ -540,12 +592,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
_ub_comm->sms = ori_sms;
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return;
......@@ -576,10 +628,11 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)_stream_comm));
}
torch::Tensor &get_ubuf_output(int comm_type) {
......@@ -609,7 +662,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
void *_ubuf_ptr;
torch::Tensor _ubuf;
torch::Tensor counter;
torch::Tensor _empty_tensor;
torch::Tensor _ubuf_scale_inv;
bool _ubuf_scale_inv_initialized;
std::vector<torch::Tensor> _ubufs;
......@@ -622,29 +674,30 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int _cga_size;
bool _atomic_gemm;
UbufP2PCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size,
int num_comm_sm, int comm_cga_size, bool set_sm_margin, bool aggregate2,
int num_max_streams, bool is_reduce_scatter, bool atomic_gemm, bool use_ce,
torch::Tensor empty_tensor) {
UbufP2PCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size,
bool set_sm_margin, bool aggregate2, int num_max_streams,
bool is_reduce_scatter, bool atomic_gemm, bool use_ce,
UbufBootstrapCallbacks &callbacks) {
// Initialize userbuf communicator
if (!comm_created) {
if (rank == 0) {
if (myrank == 0) {
printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n");
}
if (transformer_engine::getenv<bool>("UB_MPI_BOOTSTRAP")) {
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
} else {
create_communicator_grouped2(&_ub_comm, rank, world_size, tp_rank, tp_size, 1, 1,
&ub_alloc_copy_allgather, &ub_barrier, &ub_free, 1, 1, tp_size,
1);
}
#else
create_communicator_grouped2(
&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5),
std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1);
#endif
comm_created = true;
}
_use_ce = use_ce;
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
_empty_tensor = empty_tensor;
// Create workspace tensor with userbuffer
int ubuf_bytes = sample.numel() * sample.element_size();
int ubuf_chunk_bytes = ubuf_bytes / tp_size;
......@@ -655,15 +708,23 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1));
num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1);
}
if (transformer_engine::getenv<bool>("UB_SKIPMC")) {
_ubuf = torch::zeros({sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)},
sample.options());
_ubuf_ptr = _ubuf.data_ptr();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, false);
} else {
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
if (rank == 0) {
_ubuf =
torch::from_blob(_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)},
sample.options());
}
if (_ub_comm->myrank == 0) {
printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
}
_ubuf = torch::from_blob(
_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options());
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
for (int i = 0; i < num_ubuf_chunks; i++) {
......@@ -690,23 +751,23 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_tp_size = tp_size;
_aggregate2 = aggregate2;
_rank = rank;
_tp_id = (rank % tp_size);
_rank_round_tp = (rank / tp_size) * tp_size;
_next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp;
_prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp;
_rank = _ub_comm->myrank;
_tp_id = (_rank % _tp_size);
_rank_round_tp = (_rank / _tp_size) * _tp_size;
_next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp;
_prev_rank = (_tp_size + _rank - 1) % _tp_size + _rank_round_tp;
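    // Ring neighbors are computed within this rank's tensor-parallel group: _rank_round_tp is
    // the first global rank of the group, and _next_rank/_prev_rank wrap around inside
    // [_rank_round_tp, _rank_round_tp + _tp_size).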
_ubuf_scale_inv_initialized = false;
_atomic_gemm = atomic_gemm;
_self_chunk_id = _tp_id;
if (_atomic_gemm) {
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({tp_size * 2}, counter_options);
counter.index_put_({Slice(None, tp_size)}, 1);
counter = torch::zeros({_tp_size * 2}, counter_options);
counter.index_put_({Slice(None, _tp_size)}, 1);
if (!is_reduce_scatter) {
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (rank == 0 && env_p != nullptr) {
if (_rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') {
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
}
......@@ -724,6 +785,25 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
cudaEventCreateWithFlags(&_stop_recv, 0);
}
~UbufP2PCommOverlap() {
cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
if (comm_created) {
#ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm);
#else
destroy_communicator(_ub_comm);
#endif
comm_created = false;
}
}
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
......@@ -766,9 +846,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
......@@ -809,12 +889,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
CHECK_CUDA(
NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
// Reset atomic counters
......@@ -822,7 +902,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.data_ptr());
CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr,
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr,
n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
// Return the last N rows of D_buffer
......@@ -871,12 +951,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
if (_aggregate2) {
const int num_steps = _tp_size / 2;
......@@ -892,9 +972,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
(cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_recv);
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
......@@ -931,14 +1011,14 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, (cudaStream_t)_stream_recv);
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
CHECK_CUDA(cudaStreamWaitEvent(
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
}
......@@ -976,27 +1056,27 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, (cudaStream_t)_stream_recv);
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
CHECK_CUDA(cudaStreamWaitEvent(
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
}
}
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return D;
......@@ -1032,8 +1112,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Catch up the main stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
......@@ -1059,8 +1139,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
(cudaStream_t)_stream_recv);
}
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
......@@ -1113,11 +1193,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Catch up the main stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
// GEMM and send/recv chunks
......@@ -1145,18 +1225,18 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int recv_offset = comm_bytes * (i - 1 + _tp_size);
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
CHECK_CUDA(cudaEventRecord(
NVTE_CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
send_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
recv_rank, (cudaStream_t)_stream_recv);
}
}
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
......@@ -1174,11 +1254,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
torch::sum_out(rs_output, reduce_buf, 0);
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
/*
......@@ -1191,16 +1271,16 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
} else {
if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(),
input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
}
}
......
......@@ -206,9 +206,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
// Communication functions to initialize Userbuffers communicators
// Note: Callbacks are not called, so safe to release GIL.
m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks,
m.def("device_supports_multicast", &ubuf::device_supports_multicast,
py::call_guard<py::gil_scoped_release>());
m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi,
py::call_guard<py::gil_scoped_release>());
py::class_<ubuf::UbufBootstrapCallbacks>(m, "UbufBootstrapCallbacks")
.def(py::init<>(), py::call_guard<py::gil_scoped_release>())
.def(py::init<c10d::ProcessGroup *, c10d::ProcessGroup *>(),
py::call_guard<py::gil_scoped_release>());
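  // UbufBootstrapCallbacks is constructed either with no arguments (presumably only valid when
  // Userbuffers is compiled with NVTE_UB_WITH_MPI=1 and bootstraps over MPI instead) or with
  // two torch.distributed process groups that back the ub_allgather/ub_barrier callbacks used
  // while creating the communicator.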
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
......@@ -225,8 +231,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, int, bool, int, bool,
torch::Tensor>())
.def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, int, bool, int,
bool, ubuf::UbufBootstrapCallbacks &>(),
py::call_guard<py::gil_scoped_release>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs,
......@@ -250,8 +257,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, bool,
torch::Tensor>())
.def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, bool, bool, int,
bool, bool, bool, ubuf::UbufBootstrapCallbacks &>(),
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs,
......
......@@ -7,66 +7,82 @@
#include "ipcsocket.h"
#include <errno.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#define WARN(...) \
{}
#define TRACE(...) \
{}
#define SYSCHECK(...) \
{}
#define EQCHECK(...) \
{}
#define IPC_MAX_MSGLEN 4096
// Enable Linux abstract socket naming
#define USE_ABSTRACT_SOCKET
void ipc_warn(const char *format, ...) {
char buffer[IPC_MAX_MSGLEN];
#define NCCL_IPC_SOCKNAME_STR "/tmp/nccl-socket-%d-%lx"
va_list args;
va_start(args, format);
vsnprintf(buffer, IPC_MAX_MSGLEN - 1, format, args);
snprintf(buffer + strlen(buffer), IPC_MAX_MSGLEN - strlen(buffer) - 1, " : %s (%d)\n",
strerror(errno), errno);
fflush(stdout);
fputs(buffer, stderr);
fflush(NULL);
va_end(args);
}
static const char *ipcSocketResultStrings[static_cast<int>(ipcSocketNumResults)] = {
"Success", "Unhandled CUDA error", "System error", "Internal error",
"Invalid argument", "Invalid usage", "Remote error", "In progress",
};
const char *ipcSocketGetErrorString(ipcSocketResult_t res) {
return ipcSocketResultStrings[static_cast<int>(res)];
}
#define USE_ABSTRACT_SOCKET // Enable Linux abstract socket naming
#define IPC_SOCKNAME_STR "/tmp/ub-ipc-socket-%d-%lx"
/*
* Create a Unix Domain Socket
*/
ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash,
ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
volatile uint32_t *abortFlag) {
int fd = -1;
struct sockaddr_un cliaddr;
char temp[NCCL_IPC_SOCKNAME_LEN] = "";
char temp[IPC_SOCKNAME_LEN] = "";
if (handle == NULL) {
return ncclInternalError;
return ipcSocketInternalError;
}
handle->fd = -1;
handle->socketName[0] = '\0';
if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
WARN("UDS: Socket creation error : %s (%d)", strerror(errno), errno);
return ncclSystemError;
ipc_warn("UDS: Socket creation error");
return ipcSocketSystemError;
}
bzero(&cliaddr, sizeof(cliaddr));
cliaddr.sun_family = AF_UNIX;
// Create unique name for the socket.
size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash);
size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) {
WARN("UDS: Cannot bind provided name to socket. Name too large");
return ncclInternalError;
errno = ENAMETOOLONG;
ipc_warn("UDS: Cannot bind provided name to socket. Name too large");
return ipcSocketInternalError;
}
#ifndef USE_ABSTRACT_SOCKET
unlink(temp);
#endif
TRACE(NCCL_INIT, "UDS: Creating socket %s", temp);
strncpy(cliaddr.sun_path, temp, len);
#ifdef USE_ABSTRACT_SOCKET
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#else
unlink(temp);
#endif
if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) {
WARN("UDS: Binding to socket %s failed : %s (%d)", temp, strerror(errno), errno);
ipc_warn("UDS: Binding to socket %s failed", temp);
close(fd);
return ncclSystemError;
return ipcSocketSystemError;
}
handle->fd = fd;
......@@ -79,24 +95,25 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash,
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
return ncclSuccess;
return ipcSocketSuccess;
}
ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd) {
ipcSocketResult_t ipcSocketGetFd(struct IpcSocketHandle *handle, int *fd) {
if (handle == NULL) {
WARN("ncclSocketGetFd: pass NULL socket");
return ncclInvalidArgument;
errno = EINVAL;
ipc_warn("ipcSocketSocketGetFd: pass NULL socket");
return ipcSocketInvalidArgument;
}
if (fd) *fd = handle->fd;
return ncclSuccess;
return ipcSocketSuccess;
}
ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) {
ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle) {
if (handle == NULL) {
return ncclInternalError;
return ipcSocketInternalError;
}
if (handle->fd <= 0) {
return ncclSuccess;
return ipcSocketSuccess;
}
#ifndef USE_ABSTRACT_SOCKET
if (handle->socketName[0] != '\0') {
......@@ -105,10 +122,10 @@ ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) {
#endif
close(handle->fd);
return ncclSuccess;
return ipcSocketSuccess;
}
ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, int *recvFd) {
ipcSocketResult_t ipcSocketRecvMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, int *recvFd) {
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1];
......@@ -138,39 +155,44 @@ ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
WARN("UDS: Receiving data over socket failed : %d", errno);
return ncclSystemError;
ipc_warn("UDS: Receiving data over socket failed");
return ipcSocketSystemError;
}
if (handle->abortFlag && *handle->abortFlag) return ncclInternalError;
if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
}
if (recvFd != NULL) {
if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
WARN("UDS: Receiving data over socket failed");
return ncclSystemError;
errno = EBADMSG;
ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
return ipcSocketSystemError;
}
memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
} else {
WARN("UDS: Receiving data over socket %s failed", handle->socketName);
return ncclSystemError;
errno = ENOMSG;
ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
return ipcSocketSystemError;
}
TRACE(NCCL_INIT | NCCL_P2P, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName);
} else {
errno = EINVAL;
ipc_warn("UDS: File descriptor pointer cannot be NULL");
return ipcSocketInvalidArgument;
}
return ncclSuccess;
return ipcSocketSuccess;
}
ncclResult_t ncclIpcSocketRecvFd(ncclIpcSocket *handle, int *recvFd) {
return ncclIpcSocketRecvMsg(handle, NULL, 0, recvFd);
ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *recvFd) {
return ipcSocketRecvMsg(handle, NULL, 0, recvFd);
}
ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, const int sendFd,
ipcSocketResult_t ipcSocketSendMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, const int sendFd,
int rank, uint64_t hash) {
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1];
char temp[NCCL_IPC_SOCKNAME_LEN];
char temp[IPC_SOCKNAME_LEN];
union {
struct cmsghdr cm;
......@@ -185,10 +207,11 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
bzero(&cliaddr, sizeof(cliaddr));
cliaddr.sun_family = AF_UNIX;
size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash);
size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) {
WARN("UDS: Cannot connect to provided name for socket. Name too large");
return ncclInternalError;
errno = ENAMETOOLONG;
ipc_warn("UDS: Cannot connect to provided name for socket. Name too large");
return ipcSocketInternalError;
}
(void)strncpy(cliaddr.sun_path, temp, len);
......@@ -196,11 +219,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#endif
TRACE(NCCL_INIT, "UDS: Sending hdr %p len %d to UDS socket %s", hdr, hdrLen, temp);
if (sendFd != -1) {
TRACE(NCCL_INIT, "UDS: Sending fd %d to UDS socket %s", sendFd, temp);
msg.msg_control = control_un.control;
msg.msg_controllen = sizeof(control_un.control);
......@@ -228,15 +247,16 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
ssize_t sendResult;
while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
WARN("UDS: Sending data over socket %s failed : %s (%d)", temp, strerror(errno), errno);
return ncclSystemError;
ipc_warn("UDS: Sending data over socket %s failed", temp);
return ipcSocketSystemError;
}
if (handle->abortFlag && *handle->abortFlag) return ncclInternalError;
if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
}
return ncclSuccess;
return ipcSocketSuccess;
}
ncclResult_t ncclIpcSocketSendFd(ncclIpcSocket *handle, const int sendFd, int rank, uint64_t hash) {
return ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int sendFd, int rank,
uint64_t hash) {
return ipcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
}
......@@ -4,10 +4,9 @@
* See LICENSE for license information.
************************************************************************/
#ifndef NCCL_IPCSOCKET_H
#define NCCL_IPCSOCKET_H
#ifndef TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
#define TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
// #include "nccl.h"
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
......@@ -21,32 +20,33 @@
#include <unistd.h>
typedef enum {
ncclSuccess = 0,
ncclUnhandledCudaError = 1,
ncclSystemError = 2,
ncclInternalError = 3,
ncclInvalidArgument = 4,
ncclInvalidUsage = 5,
ncclRemoteError = 6,
ncclInProgress = 7,
ncclNumResults = 8
} ncclResult_t;
#define NCCL_IPC_SOCKNAME_LEN 64
struct ncclIpcSocket {
ipcSocketSuccess = 0,
ipcSocketUnhandledCudaError = 1,
ipcSocketSystemError = 2,
ipcSocketInternalError = 3,
ipcSocketInvalidArgument = 4,
ipcSocketInvalidUsage = 5,
ipcSocketRemoteError = 6,
ipcSocketInProgress = 7,
ipcSocketNumResults = 8
} ipcSocketResult_t;
const char *ipcSocketGetErrorString(ipcSocketResult_t res);
#define IPC_SOCKNAME_LEN 64
struct IpcSocketHandle {
int fd;
char socketName[NCCL_IPC_SOCKNAME_LEN];
char socketName[IPC_SOCKNAME_LEN];
volatile uint32_t *abortFlag;
};
ncclResult_t ncclIpcSocketInit(struct ncclIpcSocket *handle, int rank, uint64_t hash,
ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
volatile uint32_t *abortFlag);
ncclResult_t ncclIpcSocketClose(struct ncclIpcSocket *handle);
ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd);
ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle);
ipcSocketResult_t ipcSocketGetFd(IpcSocketHandle *handle, int *fd);
ncclResult_t ncclIpcSocketRecvFd(struct ncclIpcSocket *handle, int *fd);
ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd, int rank,
uint64_t hash);
ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *fd);
ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int fd, int rank, uint64_t hash);
#endif /* NCCL_IPCSOCKET_H */
#endif /* TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H */
......@@ -19,15 +19,52 @@
#include <map>
#include <utility>
#include "../util/cuda_driver.h"
#include "common/util/cuda_driver.h"
#include "common/util/logging.h"
#include "ipcsocket.h"
#include "userbuffers.h"
#ifdef UB_MPI_BOOTSTRAP
#include <mpi.h>
#ifdef NVTE_UB_WITH_MPI
static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD;
static MPI_Comm EXT_COMM_INTRA;
static MPI_Comm EXT_COMM_INTER;
#define UB_MPI_CHECK(expr) \
do { \
const int mpicode = (expr); \
if (mpicode != MPI_SUCCESS) { \
char mpimsg[MPI_MAX_ERROR_STRING]; \
int mpilen; \
MPI_Error_string(mpicode, mpimsg, &mpilen); \
std::vector<char> errmsg(1024); \
snprintf(errmsg.data(), errmsg.size(), "%s:%d in function %s: %s", __FILE__, __LINE__, \
__func__, mpimsg); \
throw std::runtime_error(errmsg.data()); \
} \
} while (false)
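// MPI fallback for the allgather hook: each rank first copies its own chunk into its slot of
// the global buffer, then every rank broadcasts its slot. Functionally this should match an
// in-place gather such as (counts cast to int):
//   MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm);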
void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
ExtComm group) {
// UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE,
// globaldata, globalbytes, MPI_BYTE,
// static_cast<MPI_Comm>(group)));
MPI_Comm comm = static_cast<MPI_Comm>(group);
int numranks;
UB_MPI_CHECK(MPI_Comm_size(comm, &numranks));
assert(globalbytes == numranks * localbytes);
int myrank;
UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank));
char *globaltarget = reinterpret_cast<char *>(globaldata) + (myrank * localbytes);
memcpy(globaltarget, localdata, localbytes);
for (int n = 0; n < numranks; n++) {
globaltarget = reinterpret_cast<char *>(globaldata) + (n * localbytes);
UB_MPI_CHECK(MPI_Bcast(globaltarget, localbytes, MPI_BYTE, n, comm));
}
}
void ub_mpi_barrier(ExtComm group) { UB_MPI_CHECK(MPI_Barrier(static_cast<MPI_Comm>(group))); }
#else
static char EXT_COMM_WORLD[] = "world";
static char EXT_COMM_INTRA[] = "intra";
......@@ -38,33 +75,19 @@ static char EXT_COMM_INTER[] = "inter";
int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define NVTE_UB_ERROR(x) \
#define IPCCHECK(cmd) \
do { \
throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
" in function " + __func__ + ": " + x); \
} while (false)
#define NCCLCHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d ''\n", __FILE__, __LINE__ /*,ncclGetErrorString(r)*/); \
ipcSocketResult_t r = cmd; \
if (r != ipcSocketSuccess) { \
printf("Failed, UDS error %s:%d '%s'\n", __FILE__, __LINE__, ipcSocketGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define NCCLCHECKGOTO(call, RES, label) \
#define IPCCHECKGOTO(call, RES, label) \
do { \
RES = call; \
if (RES != ncclSuccess && RES != ncclInProgress) { \
if (RES != ipcSocketSuccess && RES != ipcSocketInProgress) { \
goto label; \
} \
} while (0);
......@@ -85,15 +108,14 @@ int pipe_rank(communicator *comm, int step) {
int create_communicator_grouped2(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void **, void *, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier, std::function<void(void *)> ext_free, int pipegpus,
int pipenodes, int tensorgpus, int tensornodes) {
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes) {
*comm = new communicator();
(*comm)->comm_world = EXT_COMM_WORLD;
(*comm)->_alloc_copy_allgather = ext_alloc_copy_allgather;
(*comm)->_allgather = ext_allgather;
(*comm)->_barrier = ext_barrier;
(*comm)->_free = ext_free;
(*comm)->nranks = numranks;
(*comm)->myrank = myrank;
(*comm)->free_region = 0;
......@@ -101,9 +123,9 @@ int create_communicator_grouped2(
int cur_dev, ndev;
cudaDeviceProp device_prop;
CUDACHECK(cudaGetDevice(&cur_dev));
CUDACHECK(cudaGetDeviceCount(&ndev));
CUDACHECK(cudaGetDeviceProperties(&device_prop, cur_dev));
NVTE_CHECK_CUDA(cudaGetDevice(&cur_dev));
NVTE_CHECK_CUDA(cudaGetDeviceCount(&ndev));
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&device_prop, cur_dev));
(*comm)->sm_arch = device_prop.major;
// (*comm)->use_rr_kernel = device_prop.major == 8;
(*comm)->use_rr_kernel = 0;
......@@ -119,7 +141,7 @@ int create_communicator_grouped2(
int device_clock = 0;
// 110 sec wait time by default
int sec_timeout = getenv("UB_TIMEOUT") ? atoi(getenv("UB_TIMEOUT")) : 110;
CUDACHECK(cudaDeviceGetAttribute(&device_clock, cudaDevAttrClockRate, cur_dev));
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&device_clock, cudaDevAttrClockRate, cur_dev));
(*comm)->ub_timeout = 1000ull * device_clock * sec_timeout;
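  // cudaDevAttrClockRate is reported in kHz, so kHz * 1000 * seconds converts the UB_TIMEOUT
  // wall-clock budget into GPU clock cycles.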
if ((*comm)->myrank == 0) {
printf("UB_TIMEOUT is set to %d sec, %" PRIu64 " cycles, freq: %dkhz\n", sec_timeout,
......@@ -154,7 +176,7 @@ int create_communicator_grouped2(
if (ndev == numlocal) { // all visible devices
if (cur_dev != mylocal)
printf("%d: device used %d[%d] ,resetting device to %d\n", myrank, cur_dev, ndev, mylocal);
CUDACHECK(cudaSetDevice(mylocal));
NVTE_CHECK_CUDA(cudaSetDevice(mylocal));
}
(*comm)->mydev = cur_dev;
// FIXME need to check that numlocal is multiple of pipegpus x tensorgpus
......@@ -213,14 +235,14 @@ int create_communicator_grouped2(
// Broadcast a POSIX file descriptor from the local root rank to other local ranks.
// NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the
// file descriptor and prevent cuMemImportFromShareableHandle() from correctly
// interpreting the file. Instead, we use system socket to send/recv the file handle
// without mangling.
//       interpreting the file. Instead, we use Unix domain sockets so that the kernel
//       recreates the correct file descriptor on every receiving rank.
int fd;
volatile uint32_t abortFlag = 0;
struct ncclIpcSocket ipcSock = {0};
IpcSocketHandle ipcSock = {0};
uint64_t opId = 0xdeadcafeb000 + (*comm)->ar2_firstgpu;
ncclResult_t ret = ncclSuccess;
NCCLCHECK(ncclIpcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag));
ipcSocketResult_t ret = ipcSocketSuccess;
IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag));
(*comm)->_barrier((*comm)->comm_world);
if ((*comm)->ar2_nvrank == 0) {
......@@ -232,19 +254,22 @@ int create_communicator_grouped2(
for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
(*comm)->_barrier((*comm)->comm_intra);
NCCLCHECKGOTO(ncclIpcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error);
IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error);
}
} else {
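    // Receiver side of the fd hand-off: each iteration p mirrors the sender loop above. All
    // local ranks enter the intra-node barrier, then only rank p posts the matching
    // ipcSocketRecvFd, so a single UDS transfer is in flight per iteration.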
for (int i = 0; i < (*comm)->ar2_nvrank; i++) (*comm)->_barrier((*comm)->comm_intra);
NCCLCHECKGOTO(ncclIpcSocketRecvFd(&ipcSock, &fd), ret, error);
for (int i = 0; i < (*comm)->ar2_nvsize - (*comm)->ar2_nvrank - 1; i++)
for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
(*comm)->_barrier((*comm)->comm_intra);
if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error);
}
}
error:
if ((*comm)->ar2_nvrank != 0) {
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast<void *>(fd),
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
}
error:
NCCLCHECK(ncclIpcSocketClose(&ipcSock));
IPCCHECK(ipcSocketClose(&ipcSock));
close(fd);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle,
(CUdeviceptr)(*comm)->mydev);
......@@ -275,14 +300,16 @@ int create_communicator_grouped2(
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
CUDACHECK(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(
cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet
NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false);
CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(
cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
(*comm)->sms = 16;
(*comm)->threads = 1024;
......@@ -291,8 +318,8 @@ int create_communicator_grouped2(
#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1)
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
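  // Allocating two GPU pages and rounding the pointer up to the next GPU_PAGE_SIZE boundary
  // guarantees one fully page-aligned page of flag storage regardless of where cudaMalloc
  // placed the allocation.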
......@@ -321,75 +348,73 @@ int create_communicator_grouped2(
int create_communicator_grouped(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void **, void *, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier, std::function<void(void *)> ext_free, int pipegpus,
int pipenodes) {
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes) {
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_alloc_copy_allgather, ext_barrier, ext_free, pipegpus,
pipenodes, 1, 1);
ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1);
}
int create_communicator(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void **, void *, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier, std::function<void(void *)> ext_free) {
int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes,
std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier) {
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_alloc_copy_allgather, ext_barrier, ext_free, 1, 1, 1, 1);
ext_allgather, ext_barrier, 1, 1, 1, 1);
}
int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
int tensorgpus, int tensornodes) {
#ifdef UB_MPI_BOOTSTRAP
#ifdef NVTE_UB_WITH_MPI
// get global numbers
int myrank, numranks;
MPI_Comm_rank(EXT_COMM_WORLD, &myrank);
MPI_Comm_size(EXT_COMM_WORLD, &numranks);
UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank));
UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks));
// find intranode numbers and make internode communicator
char host_name[MPI_MAX_PROCESSOR_NAME];
char(*host_names)[MPI_MAX_PROCESSOR_NAME];
int namelen, bytes, color;
int rank = (*comm)->myrank, size = (*comm)->nranks;
MPI_Get_processor_name(host_name, &namelen);
bytes = size * sizeof(char[MPI_MAX_PROCESSOR_NAME]);
host_names = (char(*)[MPI_MAX_PROCESSOR_NAME])malloc(bytes);
strcpy(host_names[rank], host_name); // NOLINT(*)
for (int n = 0; n < size; n++)
MPI_Bcast(&(host_names[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD);
qsort(host_names, size, sizeof(char[MPI_MAX_PROCESSOR_NAME]), stringCmp);
color = 0;
for (int n = 0; n < size; n++) {
if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++;
if (strcmp(host_name, host_names[n]) == 0) break;
char hostname[MPI_MAX_PROCESSOR_NAME];
int namelen;
UB_MPI_CHECK(MPI_Get_processor_name(hostname, &namelen));
char(*hostnames)[MPI_MAX_PROCESSOR_NAME] =
static_cast<char(*)[MPI_MAX_PROCESSOR_NAME]>(malloc(numranks * MPI_MAX_PROCESSOR_NAME));
strcpy(hostnames[myrank], hostname);
for (int n = 0; n < numranks; n++)
UB_MPI_CHECK(MPI_Bcast(&(hostnames[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD));
qsort(hostnames, numranks, MPI_MAX_PROCESSOR_NAME, stringCmp);
int color = 0;
for (int n = 0; n < numranks; n++) {
if (n > 0 && strcmp(hostnames[n - 1], hostnames[n])) color++;
if (strcmp(hostname, hostnames[n]) == 0) break;
}
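  // After sorting, `color` is the index of this host among the distinct hostnames, so all ranks
  // on the same node compute the same color and land in the same EXT_COMM_INTRA split below.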
free(host_names);
free(hostnames);
int mylocal, numlocal;
MPI_Comm_split(EXT_COMM_WORLD, color, rank, &EXT_COMM_INTRA);
MPI_Comm_rank(EXT_COMM_INTRA, &mylocal);
MPI_Comm_size(EXT_COMM_INTRA, &numlocal);
UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, color, myrank, &EXT_COMM_INTRA));
UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal));
UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal));
// find internode numbers and make internode communicator
CUDACHECK(cudaFree(0));
NVTE_CHECK_CUDA(cudaFree(0));
int allnodes = numranks / numlocal;
int datanodes = allnodes / pipenodes / tensornodes;
// data reduction group this node belongs to; equals 0 for all ranks when pipenodes=1 and tensornodes=1
int datanodegroup_id = myrank / numlocal / datanodes;
// mpi communicator only needed for SHARP which is always allreduce1/data-parallel
MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &EXT_COMM_INTER);
UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, myrank,
&EXT_COMM_INTER));
// different rails from same group are in different subcommunicators
int mynode, numnodes;
MPI_Comm_size(EXT_COMM_INTER, &numnodes);
MPI_Comm_rank(EXT_COMM_INTER, &mynode);
UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTER, &numnodes));
UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTER, &mynode));
// finally call the abstracted constructor with MPI info
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
&ub_alloc_copy_allgather, &ub_barrier, &ub_free, pipegpus,
pipenodes, tensorgpus, tensornodes);
&ub_mpi_allgather, &ub_mpi_barrier, pipegpus, pipenodes,
tensorgpus, tensornodes);
#else
NVTE_UB_ERROR(std::string("Bootstrapping Userbuffers with MPI requires ") +
std::string("building Transformer Engine with UB_MPI_BOOTSTRAP=1"));
NVTE_ERROR(std::string("Bootstrapping Userbuffers with MPI requires building") +
std::string("Transformer Engine with NVTE_UB_WITH_MPI=1 and MPI_HOME=/path/to/mpi"));
#endif
}
......@@ -403,48 +428,45 @@ int create_communicator_mpi(communicator **comm) {
void destroy_communicator(communicator *comm) {
for (int hndl = 0; hndl < comm->free_region; hndl++) {
if (comm->mem_dealloc[hndl]) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree,
reinterpret_cast<CUdeviceptr>(comm->ucbase_ptr[hndl]),
comm->mem_size[hndl] * comm->nvsize);
if (hndl > 0 && comm->use_mc && comm->mem_dealloc[hndl]) {
for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank == comm->nvrank) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);
} else {
comm->uchandles[hndl][rank] = 0;
}
}
free(reinterpret_cast<void *>(comm->uchandles[hndl]));
} else {
for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank != comm->nvrank) {
cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]);
} else if (comm->mem_dealloc[hndl]) {
NVTE_CHECK_CUDA(cudaFree(comm->peer_ptr[hndl][rank]));
} else {
comm->peer_ptr[hndl][rank] = nullptr; // remove reference to external buffer
}
}
free(comm->peer_ptr[hndl]);
}
free(comm->peer_ptr[hndl]);
comm->mem_ptr[hndl] = nullptr;
}
cudaFree(reinterpret_cast<void *>(comm->flags));
cudaFree(reinterpret_cast<void *>(comm->recv_id));
cudaFree(reinterpret_cast<void *>(comm->send_id));
if (comm->use_mc) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, reinterpret_cast<CUdeviceptr>(comm->mc_baseptr),
comm->mc_maxsize);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle);
}
if (comm->mem_dealloc[0]) {
cudaFree(comm->gpu_ptrs);
}
free(comm->fifo);
delete comm;
}
void destroy_communicator_mpi(communicator *comm) {
#ifdef UB_MPI_BOOTSTRAP
MPI_Comm_free(comm->comm_inter);
MPI_Comm_free(comm->comm_intra);
#ifdef NVTE_UB_WITH_MPI
MPI_Comm_free(static_cast<MPI_Comm *>(&(comm->comm_inter)));
MPI_Comm_free(static_cast<MPI_Comm *>(&(comm->comm_intra)));
destroy_communicator(comm);
#else
NVTE_UB_ERROR(std::string("Communicator is not bootstrapped with MPI and ") +
NVTE_ERROR(std::string("Communicator is not bootstrapped with MPI and ") +
std::string("can only be deallocated with destroy_communicator()."));
#endif
}
......@@ -457,7 +479,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->memflags[hndl] = 0;
comm->mem_dealloc[hndl] = alloc;
if (alloc) {
if (comm->use_mc && alloc) {
int nranks = comm->nvsize; // total GPUs in NVLINK domain
int myrank = comm->nvrank;
void **remptrs = reinterpret_cast<void **>(malloc(nranks * sizeof(void *)));
......@@ -501,26 +523,22 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
(uint64_t)0);
volatile uint32_t abortFlag = 0;
struct ncclIpcSocket ipcSock = {0};
IpcSocketHandle ipcSock = {0};
uint64_t opId = 0xdeadcafebeef;
ncclResult_t ret = ncclSuccess;
// All-gather POSIX file descriptors across local ranks.
// NOTE: This cannot be done via MPI_Allgather or other external comm libraries. They mangle
// the file descriptor and prevent cuMemImportFromShareableHandle() from correctly
// interpreting the file. Instead, we use system socket to send/recv the file handle
// without mangling.
NCCLCHECK(ncclIpcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag));
ipcSocketResult_t ret = ipcSocketSuccess;
// All-gather POSIX file descriptors across local ranks
IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag));
for (int p = 1; p < nranks; p++) {
int send_to = (myrank + p) % nranks;
int recv_from = (myrank + nranks - p) % nranks;
comm->_barrier(comm->comm_intra);
NCCLCHECKGOTO(
ncclIpcSocketSendFd(&ipcSock, peerfd[myrank], (myrank + p) % nranks, (uint64_t)opId), ret,
error);
NCCLCHECKGOTO(ncclIpcSocketRecvFd(&ipcSock, &peerfd[(myrank + nranks - p) % nranks]), ret,
error);
IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, error);
IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error);
}
error:
IPCCHECK(ipcSocketClose(&ipcSock));
for (int p = 0; p < nranks; p++) {
if (p != myrank)
......@@ -530,6 +548,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(peerfd[p]);
}
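// NOTE: closing peerfd[p] above should be safe; once cuMemImportFromShareableHandle() returns,
//       the imported allocation handle no longer depends on the original file descriptor.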
free(peerfd);
CUdeviceptr ptr;
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks),
(size_t)0, (CUdeviceptr)0, (uint64_t)0);
......@@ -554,12 +574,11 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemSetAccess, ptr, (size_t)(aligned_size * nranks),
const_cast<CUmemAccessDesc *>(&accessDesc), (size_t)1);
if (hndl == 0) NVTE_CHECK_CUDA(cudaMemset(comm->gpu_ptrs, 0, aligned_size));
NVTE_CHECK_CUDA(
cudaMemcpy((reinterpret_cast<char *>(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)),
remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice));
free(remptrs);
comm->memflags[hndl] = UB_MEM_UC_CONTIG | UB_MEM_ALLOCATED;
if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) {
......@@ -575,29 +594,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
}
} else {
if (alloc) {
NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));
}
NVTE_CHECK(comm->nvsize <= 8, "CUDA IPC supports only up to 8 GPUs in an NVLink domain.");
cudaIpcMemHandle_t memhndl;
NVTE_CHECK_CUDA(cudaIpcGetMemHandle(&memhndl, *gpubuff));
cudaIpcMemHandle_t *tmp =
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(comm->nvsize * sizeof(cudaIpcMemHandle_t)));
comm->_allgather(reinterpret_cast<void *>(tmp), comm->nvsize * sizeof(cudaIpcMemHandle_t),
reinterpret_cast<void *>(&memhndl), sizeof(cudaIpcMemHandle_t),
comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*)
cudaIpcMemLazyEnablePeerAccess));
}
}
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaMemcpy(
reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
free(tmp);
}
comm->mem_size[hndl] = aligned_size;
......
......@@ -23,15 +23,6 @@
#define MAX_THREADS 1024
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define ATOMIC_CONSUMER(chunk) \
if (counters) { \
if (threadIdx.x == 0 && blockIdx.x == 0) { \
......@@ -1391,7 +1382,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag<x> \
: userbuffers_fp16_sum_inplace_gpu_rw_ag<x>), \
......@@ -1416,7 +1407,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_ag<x>), kernelArgs)); \
}
......@@ -1436,7 +1427,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs<x>), kernelArgs)); \
}
......@@ -1458,7 +1449,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs<x>), kernelArgs)); \
}
......@@ -1481,7 +1472,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop<x>), \
kernelArgs)); \
}
......@@ -1506,7 +1497,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8<x, fp8type>), \
kernelArgs)); \
......@@ -1532,7 +1523,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop<x>), \
kernelArgs)); \
}
......@@ -1562,7 +1553,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16), \
reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8<x, fp8type>), \
......@@ -1588,7 +1579,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride<x>), \
kernelArgs)); \
}
......@@ -1614,7 +1605,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic<x>), \
kernelArgs)); \
......@@ -1641,7 +1632,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15)}; \
NVTE_CHECK_CUDA( \
cudaLaunchKernelExC(&cfg, \
reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic<x>), \
......@@ -2206,15 +2197,6 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
}
}
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
// Return TRUE if two ranks share the same NV domain
#define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize))
......@@ -2259,7 +2241,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
if (comm->use_ce) {
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
......@@ -2269,7 +2251,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
reinterpret_cast<void *>(&arg5)};
NVTE_CHECK_CUDA(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
}
}
......@@ -2291,7 +2273,8 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
if (comm->use_ce) {
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
NVTE_CHECK_CUDA(
cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
......@@ -2323,7 +2306,7 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12),
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
reinterpret_cast<void *>(&arg15)};
NVTE_CHECK_CUDA(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv), kernelArgs));
}
......@@ -2346,7 +2329,8 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
if (comm->use_ce) {
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
NVTE_CHECK_CUDA(
cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
......@@ -2379,8 +2363,8 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12),
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16)};
NVTE_CHECK_CUDA(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic), kernelArgs));
}
void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
......@@ -2425,7 +2409,7 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16),
reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)};
NVTE_CHECK_CUDA(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_multiatomic), kernelArgs));
}
......@@ -2451,7 +2435,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
if (!signalonly)
kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
if (comm->use_ce) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
}
} else {
kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(
......
......@@ -15,39 +15,11 @@
#include <functional>
#include <stdexcept>
#include "common/util/logging.h"
#ifdef NVTE_UB_WITH_MPI
#include <mpi.h>
typedef MPI_Comm ExtComm;
#else
typedef char *ExtComm;
#endif
......@@ -170,14 +142,13 @@ struct communicator {
volatile int tail;
// Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
std::function<void(void *, size_t, void *, size_t, ExtComm)> _allgather;
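// _allgather(global_dst, global_bytes, local_src, local_bytes, group): gathers local_bytes from
// every rank in `group` into the caller-allocated global_dst (see register_user_buffer_collective).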
std::function<void(ExtComm)> _barrier;
ExtComm comm_world,
comm_inter, // reduction group communicator (subset of the nodes) along GPU rail
comm_intra; // full intranode (all ndev GPUS)
#ifdef NVTE_UB_WITH_MPI
MPI_Request mpihndl[NVTE_MAX_SHARP];
#endif
......@@ -194,20 +165,19 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr
/* creates communicator, allocates all internal buffers if necessary */
int create_communicator_grouped2(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes);
int create_communicator_grouped(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes);
int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes,
std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier);
int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
int tensorgpus, int tensornodes);
......
......@@ -7,6 +7,9 @@ import io
import os
import pickle
import warnings
import socket
import fcntl
import struct
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager
......@@ -79,19 +82,109 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
def initialize_ub(
shape: list,
tp_size: int,
use_fp8: bool = False,
dtype: torch.dtype = torch.bfloat16,
ub_cfgs: Optional[dict] = None,
bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
"""Initialize communicators for TP comm overlap using userbuffers."""
if not tex.device_supports_multicast():
assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
"CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
+ "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
)
global _ub_communicators
assert _ub_communicators is None, "UB communicators are already initialized."
_ub_communicators = {}
if tex.ubuf_built_with_mpi():
# Userbuffers will ignore all these values when it is built with MPI, so these are just
# placeholders based on an assumption that tp_size covers all devices in a physical node.
assert torch.distributed.is_mpi_available()
mpi_group = torch.distributed.new_group(backend="mpi")
world_rank = torch.distributed.get_rank(mpi_group)
world_size = torch.distributed.get_world_size(mpi_group)
local_rank = world_rank % tp_size
local_size = tp_size
node_id = world_rank // tp_size
num_nodes = world_size // tp_size
ub_callbacks = tex.UbufBootstrapCallbacks()
else:
assert (
torch.distributed.is_initialized()
), "torch.distributed must be initialized before Userbuffers"
if bootstrap_backend is None:
bootstrap_backend = "nccl"
if torch.distributed.is_gloo_available():
bootstrap_backend = "gloo"
elif torch.distributed.is_mpi_available():
bootstrap_backend = "mpi"
else:
assert bootstrap_backend in ["gloo", "mpi", "nccl"]
world_group = torch.distributed.new_group(backend=bootstrap_backend)
world_rank = torch.distributed.get_rank(world_group)
world_size = torch.distributed.get_world_size(world_group)
if world_rank == 0:
print(
f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n',
end="",
flush=True,
)
# Construct an intra-node communicator based on global ranks that share the same hostname
# NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host
# address on that interface instead of the hostname. This can help avoid issues when
# different hosts have the same hostname on Kubernetes clusters.
hostname = socket.gethostname()
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
)
if ifname is not None:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err
hostnames = [None for _ in range(world_size)]
torch.distributed.all_gather_object(hostnames, hostname, world_group)
intra_node_ranks = []
for i, host in enumerate(hostnames):
if host == hostname:
intra_node_ranks.append(i)
if len(intra_node_ranks) == world_size:
intra_node_group = world_group
local_rank = world_rank
local_size = world_size
intra_node_ranks = list(range(world_size))
else:
intra_node_group = torch.distributed.new_group(
backend=bootstrap_backend, ranks=intra_node_ranks
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
node_id = world_rank // local_size
num_nodes = world_size // local_size
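# e.g. with 16 ranks laid out contiguously over 2 hosts of 8 GPUs each, world_rank 10 finds 8
# matching hostnames, giving local_rank=2, local_size=8, node_id=1 and num_nodes=2.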
if local_rank == 0:
print(
f"!!! [NVTE] Number of physical nodes: {num_nodes}\n"
+ f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n",
end="",
flush=True,
)
ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group)
# Increase the workspace by the number of maximum concurrent streams
global _cublas_workspace
......@@ -127,6 +220,23 @@ def initialize_ub(
return method
raise KeyError(f"Given layer name {name} does not exist.")
def get_default_config(name):
method = get_method(name)
is_reduce_scatter = name in layers_reduce_scatter_overlap
default_cfg = {
"method": method,
"is_reduce_scatter": is_reduce_scatter,
"num_sm": 1 if method == "ring_exchange" else 16,
"cga_size": 1 if method == "ring_exchange" else 2,
"set_sm_margin": False,
"num_splits": 4 if method == "pipeline" else tp_size,
"aggregate": False,
"atomic_gemm": False,
"use_ce": True,
"fp8_buf": name in layers_all_gather_overlap,
}
return default_cfg
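# e.g. for a "ring_exchange" layer with tp_size=8 this returns num_sm=1, cga_size=1, num_splits=8,
# set_sm_margin=False, aggregate=False, atomic_gemm=False, use_ce=True, with is_reduce_scatter and
# fp8_buf set according to the layer name.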
def add_ub(
name: str,
method: str,
......@@ -180,53 +290,43 @@ def initialize_ub(
if method == "ring_exchange":
ub_obj = tex.UbufP2PCommOverlap(
sample_buffer, # Sample userbuffer
world_rank, # World rank
world_size, # World size
local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node
node_id, # Node ID
num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
set_sm_margin, # Set SM margin
aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
is_reduce_scatter, # Overlap with reduce scatter
atomic_gemm, # Use a single GEMM with atomic-counters
use_ce, # Use copy engine for P2P communications
ub_callbacks,
)
else:
ub_obj = tex.UbufCommOverlap(
sample_buffer, # Sample userbuffer
world_rank, # World rank
world_size, # World size
local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node
node_id, # Node ID
num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
num_splits, # Number of communication splits
set_sm_margin, # Set SM margin
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
atomic_gemm, # Use a single GEMM with atomic-counters
ub_callbacks,
)
_ub_communicators[name] = ub_obj
if ub_cfgs is not None:
for name in dgrad_reduce_scatter_overlap:
if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
......@@ -238,48 +338,18 @@ def initialize_ub(
methods["pipeline"].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
ub_cfg = get_default_config(name)
if ub_cfgs is not None and name in ub_cfgs:
# Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter
fp8_buf = (name in layers_all_gather_overlap) or (
ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
)
ub_cfg.update(ub_cfgs[name])
ub_cfg["fp8_buf"] = fp8_buf
add_ub(name, **ub_cfg)
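# Minimal usage sketch (illustrative only -- the layer name "fc1_fprop" and the shape values are
# assumptions, not part of this change). User entries in `ub_cfgs` are merged on top of the
# defaults from get_default_config() before each add_ub() call:
#
#     initialize_ub(
#         shape=[seq_length * micro_batch_size, hidden_size],
#         tp_size=8,
#         use_fp8=True,
#         dtype=torch.bfloat16,
#         ub_cfgs={"fc1_fprop": {"method": "ring_exchange", "num_sm": 4}},
#         bootstrap_backend="gloo",
#     )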
def get_ub(name: str):
"""Get userbuffer communicator corresponding to give key."""
global _ub_communicators
assert _ub_communicators is not None, "UB manager is not initialized."
assert name in _ub_communicators, f"UB for {name} is not registered."
return _ub_communicators[name]
......