Unverified Commit 5ee98175 authored by Alp Dener, committed by GitHub

[PyTorch] Fixing hang in `initialize_ub()` for multi-node runs after PR901 removal of MPI-dependence (#986)

* Re-implementing PR901 (removing MPI-dependence in Userbuffers) with multi-node fixes

* passing data-parallel rank/size info from torch.distributed to userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* multi-node example working with UB_SKIPMC=1 but not with multicast
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed multi-node hang in initialize_ub(), updated comm+GEMM overlap example to support multi-node mixed tensor/data parallelism, added README
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed use case when Userbuffers is asked to allocate the TP overlap buffer with UB_SKIPMC=1
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected example problem to set device by local ordinal instead of global process rank
Signed-off-by: Alp Dener <adener@nvidia.com>

* double-free fix in userbuffers destructor
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unnecessary and incorrect torch.cuda.set_device(...)
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected inter-node ranks logic
Signed-off-by: Alp Dener <adener@nvidia.com>

* generalized node ID logic in initialize_ub to handle arbitrary world rank layouts within node
Signed-off-by: Alp Dener <adener@nvidia.com>

* added single-node comm+GEMM overlap unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

* LayerNormMLP example confirmed working with 2 nodes on Eos
Signed-off-by: Alp Dener <adener@nvidia.com>

* unit test cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected DP group ranks logic in LNMLP comm+GEMM overlap example
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected enums in unit test
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect Ubuf object init signature
Signed-off-by: Alp Dener <adener@nvidia.com>

* switched default backend for Userbuffer bootstrapping to Gloo with MPI and NCCL fallbacks, and added an initialize_ub option to manually select the backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed all comm+GEMM overlap unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected all_gather use for Gloo backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* changed userbuffers allgather callback to always use all_gather() instead of all_gather_into_tensor()
Signed-off-by: Alp Dener <adener@nvidia.com>

* restored and verified old MPI-based bootstrapping via the NVTE_UB_WITH_MPI=1 option at compile time
Signed-off-by: Alp Dener <adener@nvidia.com>

* disabled scoped GIL release for comm+GEMM overlap algorithms
Signed-off-by: Alp Dener <adener@nvidia.com>

* avoid dist.init_device_mesh in comm+GEMM overlap example to support older PyTorch versions
Signed-off-by: Alp Dener <adener@nvidia.com>

* applied RS overlap FP8 fix from PR1004
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed segfault in Userbuffers destructor
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected comm+GEMM overlap unit test arguments
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed unit test run command for when Userbuffers is compiled with MPI
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Refactored torch.distributed collectives into pure C++ callbacks
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 71124c31
...@@ -77,14 +77,14 @@ def setup_pytorch_extension( ...@@ -77,14 +77,14 @@ def setup_pytorch_extension(
# Libraries # Libraries
library_dirs = [] library_dirs = []
libraries = [] libraries = []
if os.getenv("UB_MPI_BOOTSTRAP"): if os.getenv("NVTE_UB_WITH_MPI"):
assert ( assert (
os.getenv("MPI_HOME") is not None os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1" ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
mpi_home = Path(os.getenv("MPI_HOME")) mpi_home = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_home / "include") include_dirs.append(mpi_home / "include")
cxx_flags.append("-DUB_MPI_BOOTSTRAP") cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DUB_MPI_BOOTSTRAP") nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs.append(mpi_home / "lib") library_dirs.append(mpi_home / "lib")
libraries.append("mpi") libraries.append("mpi")
......
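The hunk above renames the MPI bootstrap build flag from `UB_MPI_BOOTSTRAP` to `NVTE_UB_WITH_MPI`. A minimal sketch of building with it, assuming a pip-based source install and a hypothetical MPI install prefix:

```bash
# MPI_HOME must point at an MPI installation (hypothetical path shown here)
$ export MPI_HOME=/usr/local/openmpi
# Compile Userbuffers with the MPI-based bootstrapping path
$ NVTE_UB_WITH_MPI=1 pip install .
```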
# Overlapping Communication with GEMM in TransformerEngine Modules
## Requirements
- Tensor-parallel GPUs must be on a single node and connected over NVLink/NVSwitch.
- `CUDA_DEVICE_MAX_CONNECTIONS=1` must be set in the environment.
- For best performance, point-to-point communication via _CUDA Multicast_ requires CUDA Toolkit 12.0+
  and CUDA driver 535+ on devices with compute capability 9.0 or newer.
- Devices older than compute capability 9.0 require `UB_SKIPMC=1` in the environment in order to fall
  back on a less performant implementation based on CUDA Inter-Process Communication (IPC) handles
  (see the environment sketch below).
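A minimal sketch of the environment setup described above (whether `UB_SKIPMC` is needed depends on your hardware):
```bash
# Force GPU kernels to launch in the order they are issued by the host CPU (required for overlap)
export CUDA_DEVICE_MAX_CONNECTIONS=1

# Only for devices older than compute capability 9.0: skip CUDA Multicast and
# fall back on the CUDA IPC handle based implementation
export UB_SKIPMC=1
```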
## Examples
### Single node, tensor-parallel LayerNormMLP:
Forward and backward passes with layer weights distributed over all GPUs in a single node.
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py
# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
# !!! [UB] Create UbufP2PCommOverlap Communicator
# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
# MC initialized succesfully, window size = 549755813888
# !!! [UBP2P] Register UBuf 1
# !!! [UBP2P] Register UBuf 2
# !!! [UBP2P] Register UBuf 3
# !!! [UBP2P] Register UBuf 4
# !!! [UB] Register UBuf 5
# !!! [UBP2P] Register UBuf 6
# !!! [UB] Register UBuf 7
# !!! [UB] Register UBuf 8
# !!! [UBP2P] Register UBuf 9
# !!! [UB] Register UBuf 10
# [rank0:node0] Iter 1
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 2
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 3
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 4
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 5
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
```
### Single node, mixed data- and tensor-parallel LayerNormMLP:
Uses `torch.nn.parallel.DistributedDataParallel` to replicate the model across 2 tensor-parallel
groups in a single node.
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py --num-replicas 2
# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
# [rank4:node1] |-- Created tensor-parallel group: [4, 5, 6, 7]
# [rank0:node0] |-- Created data-parallel group: [0, 4]
# [rank3:node1] |-- Created data-parallel group: [3, 7]
# [rank1:node1] |-- Created data-parallel group: [1, 5]
# [rank2:node0] |-- Created data-parallel group: [2, 6]
# !!! [UB] Create UbufP2PCommOverlap Communicator
# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
# MC initialized succesfully, window size = 549755813888
# !!! [UBP2P] Register UBuf 1
# !!! [UBP2P] Register UBuf 2
# !!! [UBP2P] Register UBuf 3
# !!! [UBP2P] Register UBuf 4
# !!! [UB] Register UBuf 5
# !!! [UBP2P] Register UBuf 6
# !!! [UB] Register UBuf 7
# !!! [UB] Register UBuf 8
# !!! [UBP2P] Register UBuf 9
# !!! [UB] Register UBuf 10
# [rank4:node1] Iter 1
# [rank0:node0] Iter 1
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 2
# [rank0:node0] Iter 2
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 3
# [rank0:node0] Iter 3
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank4:node1] |-- Optimizer step
# [rank0:node0] Iter 4
# [rank4:node1] Iter 4
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 5
# [rank0:node0] Iter 5
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
```
**NOTE:** To run with FP8 compute on supporting hardware, add the `--fp8` flag to the commands
shown above.
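For example, a sketch of the single-node tensor-parallel command with FP8 enabled (same launcher and script as above):
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py --fp8
```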
...@@ -6,17 +6,22 @@ ...@@ -6,17 +6,22 @@
import os import os
import sys import sys
import subprocess import socket
import argparse import argparse
import warnings
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
def parse_args(argv=None, namespace=None):
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers." description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
) )
...@@ -47,63 +52,182 @@ def parse_args(argv=None, namespace=None): ...@@ -47,63 +52,182 @@ def parse_args(argv=None, namespace=None):
default=False, default=False,
help="Disable the comm+GEMM overlap.", help="Disable the comm+GEMM overlap.",
) )
parser.add_argument("-v", "--verbose", action="store_true", default=False) parser.add_argument(
return parser.parse_args(argv, namespace) "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas."
)
parser.add_argument(
"--tcp-init",
action="store_true",
default=False,
help="Initialize torch.distributed with TcpStore.",
)
parser.add_argument(
"--bind-to-device",
action="store_true",
default=False,
help="Initialize torch.distributed with `device_id` to bind each rank to a single device.",
)
parser.add_argument(
"--bootstrap-backend",
type=str.lower,
default="nccl",
choices=["gloo", "mpi", "nccl"],
help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
default=False,
help="Print out from every rank instead of just the root rank of relevant process groups.",
)
parser.add_argument(
"--debug",
action="store_true",
default=False,
help="Print out additional debug information.",
)
args = parser.parse_args(argv, namespace)
return args
def _train(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N`
WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True
opts.bind_to_device = True
opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ:
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
NUM_NODES = WORLD_SIZE // LOCAL_SIZE
def dist_print(msg, group=None, end="\n", debug=False):
if debug and not opts.debug:
return
group = dist.new_group() if group is None else group
group_rank = dist.get_rank(group)
group_size = dist.get_world_size(group)
all_ranks = dist.get_process_group_ranks(group)
ranks_skip = all_ranks[1] - all_ranks[0] > 1
group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size
if group_rank == 0 or opts.verbose:
print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True)
dist.barrier(group)
# Initialize torch.distributed global process group and get DP/TP groups
torch.cuda.set_device(LOCAL_RANK)
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
if opts.tcp_init or NUM_NODES > 1:
if NUM_NODES > 1:
assert (
"MASTER_ADDR" in os.environ
), "Multi-node run requires MASTER_ADDR to be set in the environment."
MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
MASTER_PORT = os.getenv("MASTER_PORT", "1234")
dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
if opts.bind_to_device or opts.bootstrap_backend == "nccl":
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
dist_print(f"Initialized default NCCL process group with {WORLD_RANK} GPUs", nccl_world)
def train(opts): # Figure out process groups for tensor- and data-parallelism (if any)
WORLD_RANK = int(os.getenv("RANK")) if NUM_NODES > 1:
WORLD_SIZE = int(os.getenv("WORLD_SIZE")) # Create a list of world ranks on this node
hostnames = [None for _ in range(WORLD_SIZE)]
hostname = socket.gethostname()
dist.all_gather_object(hostnames, hostname)
node_ranks = []
for i, host in enumerate(hostnames):
if host == hostname:
node_ranks.append(i)
def dist_print(msg, end="\n", all_ranks=False): if opts.num_replicas > 1:
if WORLD_RANK == 0 or all_ranks: # Split node ranks into multiple replicas
print(f"[RANK-{WORLD_RANK}] {msg}", end=end) assert len(node_ranks) % opts.num_replicas == 0
tp_size = len(node_ranks) // opts.num_replicas
found_replica = False
for replica in range(opts.num_replicas):
start = replica * tp_size
end = start + tp_size
tp_ranks = node_ranks[start:end]
if WORLD_RANK in tp_ranks:
found_replica = True
break
assert found_replica
else:
# The entire node is the tensor-parallel group
tp_ranks = node_ranks
# Seed RNG tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
torch.cuda.set_device(WORLD_RANK) tp_size = dist.get_world_size(tp_group)
torch.manual_seed(opts.seed + WORLD_RANK) tp_rank = dist.get_rank(tp_group)
torch.cuda.manual_seed(opts.seed + WORLD_RANK)
# Initialize torch.distributed global process group and get TP group # Data-parallelism across TP groups
dist.init_process_group( dp_start = tp_rank
backend="nccl", dp_end = dp_start + WORLD_SIZE
rank=WORLD_RANK, dp_ranks = list(range(dp_start, dp_end, tp_size))
world_size=WORLD_SIZE, dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
device_id=torch.device(f"cuda:{WORLD_RANK}"),
) else:
tp_group = dist.new_group(backend="nccl") if opts.num_replicas > 1:
# Mixed data- and tensor-parallelism on a single node
# NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas))
node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist()
tp_ranks = mesh2d[node_idx[0], :].tolist()
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
dp_ranks = mesh2d[:, node_idx[1]].tolist()
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
else:
dp_group = None
tp_group = nccl_world
tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group) tp_size = dist.get_world_size(tp_group)
dist_print(
f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}",
group=tp_group,
)
if dp_group is not None:
dist_print(
f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}",
group=dp_group,
)
# Intialize userbuffers # Intialize userbuffers
ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad
"method": "ring_exchange",
"num_splits": 8,
"num_sm": 1,
"set_sm_margin": False,
}
rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop
"method": "ring_exchange",
"num_splits": 4,
"num_sm": 1,
"set_sm_margin": True,
}
hidden_size = opts.num_heads * opts.head_dim hidden_size = opts.num_heads * opts.head_dim
batched_size = opts.seq_length * opts.batch_size batched_size = opts.seq_length * opts.batch_size
if not opts.no_comm_overlap: if not opts.no_comm_overlap:
te.initialize_ub( te.module.base.initialize_ub(
[batched_size, hidden_size], [batched_size, hidden_size],
tp_group, tp_size,
use_fp8=opts.fp8, use_fp8=opts.fp8,
dtype=torch.bfloat16, dtype=torch.bfloat16,
ub_cfgs={ bootstrap_backend=opts.bootstrap_backend,
"fc1_fprop": ag_cfg,
"fc1_dgrad": rs_cfg,
"fc2_fprop": rs_cfg,
"fc2_dgrad": ag_cfg,
},
) )
# # Initialize the fused LayerNorm + Multi-layer Perceptron module
torch.manual_seed(opts.seed + tp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank)
model = te.LayerNormMLP( model = te.LayerNormMLP(
hidden_size, hidden_size,
opts.mlp_expansion_factor * hidden_size, opts.mlp_expansion_factor * hidden_size,
...@@ -114,11 +238,14 @@ def train(opts): ...@@ -114,11 +238,14 @@ def train(opts):
set_parallel_mode=True, set_parallel_mode=True,
sequence_parallel=True, # this is required for comm+GEMM overlap sequence_parallel=True, # this is required for comm+GEMM overlap
seq_length=opts.seq_length, seq_length=opts.seq_length,
micro_batch_size=opts.batch_size,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_overlap_rs=not opts.no_comm_overlap, ub_overlap_rs=not opts.no_comm_overlap,
ub_overlap_ag=not opts.no_comm_overlap, ub_overlap_ag=not opts.no_comm_overlap,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_bulk_dgrad=False,
ub_bulk_wgrad=not opts.no_comm_overlap,
) )
if dp_group is not None:
model = DistributedDataParallel(model, process_group=dp_group)
# Initialize optimizer with model parameters # Initialize optimizer with model parameters
optim = torch.optim.Adam(model.parameters(), lr=0.0001) optim = torch.optim.Adam(model.parameters(), lr=0.0001)
...@@ -128,10 +255,11 @@ def train(opts): ...@@ -128,10 +255,11 @@ def train(opts):
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations # Start dummy "training" iterations
dist_print("Starting training iterations...", nccl_world)
for i in range(opts.num_iters): for i in range(opts.num_iters):
dist_print(f"Iter {i+1}", all_ranks=opts.verbose) dist_print(f" Iter {i+1}", tp_group, debug=True)
dist_print("|-- Generate random input batch", all_ranks=opts.verbose) dist_print(" |-- Generate random input batch", tp_group, debug=True)
x = torch.rand( x = torch.rand(
(opts.seq_length // tp_size, opts.batch_size, hidden_size), (opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16, dtype=torch.bfloat16,
...@@ -139,30 +267,29 @@ def train(opts): ...@@ -139,30 +267,29 @@ def train(opts):
requires_grad=True, requires_grad=True,
) )
dist_print("|-- Forward pass", all_ranks=opts.verbose) dist_print(" |-- Forward pass", tp_group, debug=True)
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group): with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x) y = model(x)
dist_print("|-- Compute loss", all_ranks=opts.verbose) dist_print(" |-- Compute loss", tp_group, debug=True)
loss = y.flatten().sum() loss = y.flatten().sum()
dist_print("|-- Backward pass", all_ranks=opts.verbose) dist_print(" |-- Backward pass", tp_group, debug=True)
loss.backward() loss.backward()
dist_print("|-- Optimizer step", all_ranks=opts.verbose) dist_print(" |-- Optimizer step", tp_group, debug=True)
optim.step() optim.step()
te.destroy_ub() torch.cuda.synchronize()
dist_print("Finished training!")
te.module.base.destroy_ub()
dist_print("Destroying all process groups...", debug=True)
dist.destroy_process_group() dist.destroy_process_group()
if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True)
return 0
if __name__ == "__main__": if __name__ == "__main__":
if "TORCHELASTIC_RUN_ID" in os.environ.keys(): sys.exit(_train(_parse_args()))
args = parse_args()
train(args)
else:
subprocess.run(
["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
env=os.environ,
check=True,
)
os._exit(0)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
RNG_SEED: int = 1234
SEQ_LENGTH: int = 2024
BATCH_SIZE: int = 2
NUM_HEADS: int = 64
HEAD_DIM: int = 128
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(torch.cuda.device_count(), 4)
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
if tex.ubuf_built_with_mpi():
LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python"]
# Fall back on CUDA IPC if the platform does not support CUDA multicast
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"
# Force GPU kernels to launch in the order they're executed by the host CPU
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.")
@pytest.mark.parametrize(
"fp8,p2p,comm_type,aggregate,atomic,bulk",
[
# FP8, P2P, Type, Aggregate, Atomic, Bulk
(False, True, "AG", False, False, False),
(False, True, "AG", True, False, False),
(True, True, "AG", False, False, False),
(True, True, "AG", True, False, False),
(False, False, "RS", False, False, False),
(False, True, "RS", False, False, False),
(True, False, "RS", False, False, False),
(True, True, "RS", False, False, False),
(True, False, "RS", False, True, False),
(True, True, "RS", False, True, False),
(False, False, "AG", False, False, True),
(False, False, "RS", False, False, True),
],
ids=[
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ",
" SPLIT GEMM -> RS | BF16 | PIPELINE ",
" SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ",
" SPLIT GEMM -> RS | FP8 | PIPELINE ",
" SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ",
" ATOMIC GEMM -> RS | FP8 | PIPELINE ",
" ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ",
" BULK AG & GEMM | BF16 | PIPELINE ",
" BULK RS & GEMM | BF16 | PIPELINE ",
],
)
def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
"""
Test comm+GEMM overlap algorithms with direct calls to
te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm
"""
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = (
LAUNCH_CMD
+ [str(test_path)]
+ [
"--check-numerics",
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--comm-type={comm_type}",
]
)
if bulk:
test_cmd.append("--bulk-overlap")
else:
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if p2p:
test_cmd.append("--p2p")
if aggregate:
test_cmd.append("--aggregate")
if atomic:
if torch.cuda.get_device_properties(0).major < 9:
pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.")
test_cmd.append("--atomic")
output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False)
assert "NUMERICAL CHECK PASSED" in str(output)
...@@ -206,9 +206,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -206,9 +206,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
// Communication functions to initialize Userbuffers communicators m.def("device_supports_multicast", &ubuf::device_supports_multicast,
// Note: Callbacks are not called, so safe to release GIL. py::call_guard<py::gil_scoped_release>());
m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks,
m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi,
py::call_guard<py::gil_scoped_release>());
py::class_<ubuf::UbufBootstrapCallbacks>(m, "UbufBootstrapCallbacks")
.def(py::init<>(), py::call_guard<py::gil_scoped_release>())
.def(py::init<c10d::ProcessGroup *, c10d::ProcessGroup *>(),
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo") py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
...@@ -225,8 +231,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -225,8 +231,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// communicator with Python functions (e.g. PyTorch distributed // communicator with Python functions (e.g. PyTorch distributed
// communication) // communication)
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap") py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, int, bool, int, bool, .def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, int, bool, int,
torch::Tensor>()) bool, ubuf::UbufBootstrapCallbacks &>(),
py::call_guard<py::gil_scoped_release>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap, .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs, .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs,
...@@ -250,8 +257,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -250,8 +257,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// communicator with Python functions (e.g. PyTorch distributed // communicator with Python functions (e.g. PyTorch distributed
// communication) // communication)
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap") py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, bool, .def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, bool, bool, int,
torch::Tensor>()) bool, bool, bool, ubuf::UbufBootstrapCallbacks &>(),
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag, .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs, .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs,
......
...@@ -7,66 +7,82 @@ ...@@ -7,66 +7,82 @@
#include "ipcsocket.h" #include "ipcsocket.h"
#include <errno.h> #include <errno.h>
#include <stdarg.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#define WARN(...) \ #define IPC_MAX_MSGLEN 4096
{}
#define TRACE(...) \
{}
#define SYSCHECK(...) \
{}
#define EQCHECK(...) \
{}
// Enable Linux abstract socket naming void ipc_warn(const char *format, ...) {
#define USE_ABSTRACT_SOCKET char buffer[IPC_MAX_MSGLEN];
#define NCCL_IPC_SOCKNAME_STR "/tmp/nccl-socket-%d-%lx" va_list args;
va_start(args, format);
vsnprintf(buffer, IPC_MAX_MSGLEN - 1, format, args);
snprintf(buffer + strlen(buffer), IPC_MAX_MSGLEN - strlen(buffer) - 1, " : %s (%d)\n",
strerror(errno), errno);
fflush(stdout);
fputs(buffer, stderr);
fflush(NULL);
va_end(args);
}
static const char *ipcSocketResultStrings[static_cast<int>(ipcSocketNumResults)] = {
"Success", "Unhandled CUDA error", "System error", "Internal error",
"Invalid argument", "Invalid usage", "Remote error", "In progress",
};
const char *ipcSocketGetErrorString(ipcSocketResult_t res) {
return ipcSocketResultStrings[static_cast<int>(res)];
}
#define USE_ABSTRACT_SOCKET // Enable Linux abstract socket naming
#define IPC_SOCKNAME_STR "/tmp/ub-ipc-socket-%d-%lx"
/* /*
* Create a Unix Domain Socket * Create a Unix Domain Socket
*/ */
ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash, ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
volatile uint32_t *abortFlag) { volatile uint32_t *abortFlag) {
int fd = -1; int fd = -1;
struct sockaddr_un cliaddr; struct sockaddr_un cliaddr;
char temp[NCCL_IPC_SOCKNAME_LEN] = ""; char temp[IPC_SOCKNAME_LEN] = "";
if (handle == NULL) { if (handle == NULL) {
return ncclInternalError; return ipcSocketInternalError;
} }
handle->fd = -1; handle->fd = -1;
handle->socketName[0] = '\0'; handle->socketName[0] = '\0';
if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) { if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
WARN("UDS: Socket creation error : %s (%d)", strerror(errno), errno); ipc_warn("UDS: Socket creation error");
return ncclSystemError; return ipcSocketSystemError;
} }
bzero(&cliaddr, sizeof(cliaddr)); bzero(&cliaddr, sizeof(cliaddr));
cliaddr.sun_family = AF_UNIX; cliaddr.sun_family = AF_UNIX;
// Create unique name for the socket. // Create unique name for the socket.
size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) { if (len > (sizeof(cliaddr.sun_path) - 1)) {
WARN("UDS: Cannot bind provided name to socket. Name too large"); errno = ENAMETOOLONG;
return ncclInternalError; ipc_warn("UDS: Cannot bind provided name to socket. Name too large");
return ipcSocketInternalError;
} }
#ifndef USE_ABSTRACT_SOCKET
unlink(temp);
#endif
TRACE(NCCL_INIT, "UDS: Creating socket %s", temp);
strncpy(cliaddr.sun_path, temp, len); strncpy(cliaddr.sun_path, temp, len);
#ifdef USE_ABSTRACT_SOCKET #ifdef USE_ABSTRACT_SOCKET
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#else
unlink(temp);
#endif #endif
if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) { if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) {
WARN("UDS: Binding to socket %s failed : %s (%d)", temp, strerror(errno), errno); ipc_warn("UDS: Binding to socket %s failed", temp);
close(fd); close(fd);
return ncclSystemError; return ipcSocketSystemError;
} }
handle->fd = fd; handle->fd = fd;
...@@ -79,24 +95,25 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash, ...@@ -79,24 +95,25 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash,
fcntl(fd, F_SETFL, flags | O_NONBLOCK); fcntl(fd, F_SETFL, flags | O_NONBLOCK);
} }
return ncclSuccess; return ipcSocketSuccess;
} }
ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd) { ipcSocketResult_t ipcSocketGetFd(struct IpcSocketHandle *handle, int *fd) {
if (handle == NULL) { if (handle == NULL) {
WARN("ncclSocketGetFd: pass NULL socket"); errno = EINVAL;
return ncclInvalidArgument; ipc_warn("ipcSocketSocketGetFd: pass NULL socket");
return ipcSocketInvalidArgument;
} }
if (fd) *fd = handle->fd; if (fd) *fd = handle->fd;
return ncclSuccess; return ipcSocketSuccess;
} }
ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) { ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle) {
if (handle == NULL) { if (handle == NULL) {
return ncclInternalError; return ipcSocketInternalError;
} }
if (handle->fd <= 0) { if (handle->fd <= 0) {
return ncclSuccess; return ipcSocketSuccess;
} }
#ifndef USE_ABSTRACT_SOCKET #ifndef USE_ABSTRACT_SOCKET
if (handle->socketName[0] != '\0') { if (handle->socketName[0] != '\0') {
...@@ -105,10 +122,10 @@ ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) { ...@@ -105,10 +122,10 @@ ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) {
#endif #endif
close(handle->fd); close(handle->fd);
return ncclSuccess; return ipcSocketSuccess;
} }
ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, int *recvFd) { ipcSocketResult_t ipcSocketRecvMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, int *recvFd) {
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0}; struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1]; struct iovec iov[1];
...@@ -138,39 +155,44 @@ ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, ...@@ -138,39 +155,44 @@ ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) { while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
WARN("UDS: Receiving data over socket failed : %d", errno); ipc_warn("UDS: Receiving data over socket failed");
return ncclSystemError; return ipcSocketSystemError;
} }
if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
} }
if (recvFd != NULL) { if (recvFd != NULL) {
if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) { if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) { if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
WARN("UDS: Receiving data over socket failed"); errno = EBADMSG;
return ncclSystemError; ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
return ipcSocketSystemError;
} }
memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd)); memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
} else { } else {
WARN("UDS: Receiving data over socket %s failed", handle->socketName); errno = ENOMSG;
return ncclSystemError; ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
return ipcSocketSystemError;
} }
TRACE(NCCL_INIT | NCCL_P2P, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName); } else {
errno = EINVAL;
ipc_warn("UDS: File descriptor pointer cannot be NULL");
return ipcSocketInvalidArgument;
} }
return ncclSuccess; return ipcSocketSuccess;
} }
ncclResult_t ncclIpcSocketRecvFd(ncclIpcSocket *handle, int *recvFd) { ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *recvFd) {
return ncclIpcSocketRecvMsg(handle, NULL, 0, recvFd); return ipcSocketRecvMsg(handle, NULL, 0, recvFd);
} }
ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, const int sendFd, ipcSocketResult_t ipcSocketSendMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, const int sendFd,
int rank, uint64_t hash) { int rank, uint64_t hash) {
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0}; struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1]; struct iovec iov[1];
char temp[NCCL_IPC_SOCKNAME_LEN]; char temp[IPC_SOCKNAME_LEN];
union { union {
struct cmsghdr cm; struct cmsghdr cm;
...@@ -185,10 +207,11 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, ...@@ -185,10 +207,11 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
bzero(&cliaddr, sizeof(cliaddr)); bzero(&cliaddr, sizeof(cliaddr));
cliaddr.sun_family = AF_UNIX; cliaddr.sun_family = AF_UNIX;
size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) { if (len > (sizeof(cliaddr.sun_path) - 1)) {
WARN("UDS: Cannot connect to provided name for socket. Name too large"); errno = ENAMETOOLONG;
return ncclInternalError; ipc_warn("UDS: Cannot connect to provided name for socket. Name too large");
return ipcSocketInternalError;
} }
(void)strncpy(cliaddr.sun_path, temp, len); (void)strncpy(cliaddr.sun_path, temp, len);
...@@ -196,11 +219,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, ...@@ -196,11 +219,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#endif #endif
TRACE(NCCL_INIT, "UDS: Sending hdr %p len %d to UDS socket %s", hdr, hdrLen, temp);
if (sendFd != -1) { if (sendFd != -1) {
TRACE(NCCL_INIT, "UDS: Sending fd %d to UDS socket %s", sendFd, temp);
msg.msg_control = control_un.control; msg.msg_control = control_un.control;
msg.msg_controllen = sizeof(control_un.control); msg.msg_controllen = sizeof(control_un.control);
...@@ -228,15 +247,16 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, ...@@ -228,15 +247,16 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
ssize_t sendResult; ssize_t sendResult;
while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) { while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
WARN("UDS: Sending data over socket %s failed : %s (%d)", temp, strerror(errno), errno); ipc_warn("UDS: Sending data over socket %s failed", temp);
return ncclSystemError; return ipcSocketSystemError;
} }
if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
} }
return ncclSuccess; return ipcSocketSuccess;
} }
ncclResult_t ncclIpcSocketSendFd(ncclIpcSocket *handle, const int sendFd, int rank, uint64_t hash) { ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int sendFd, int rank,
return ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash); uint64_t hash) {
return ipcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
} }
...@@ -4,10 +4,9 @@ ...@@ -4,10 +4,9 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#ifndef NCCL_IPCSOCKET_H #ifndef TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
#define NCCL_IPCSOCKET_H #define TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
// #include "nccl.h"
#include <errno.h> #include <errno.h>
#include <fcntl.h> #include <fcntl.h>
#include <inttypes.h> #include <inttypes.h>
...@@ -21,32 +20,33 @@ ...@@ -21,32 +20,33 @@
#include <unistd.h> #include <unistd.h>
typedef enum { typedef enum {
ncclSuccess = 0, ipcSocketSuccess = 0,
ncclUnhandledCudaError = 1, ipcSocketUnhandledCudaError = 1,
ncclSystemError = 2, ipcSocketSystemError = 2,
ncclInternalError = 3, ipcSocketInternalError = 3,
ncclInvalidArgument = 4, ipcSocketInvalidArgument = 4,
ncclInvalidUsage = 5, ipcSocketInvalidUsage = 5,
ncclRemoteError = 6, ipcSocketRemoteError = 6,
ncclInProgress = 7, ipcSocketInProgress = 7,
ncclNumResults = 8 ipcSocketNumResults = 8
} ncclResult_t; } ipcSocketResult_t;
#define NCCL_IPC_SOCKNAME_LEN 64 const char *ipcSocketGetErrorString(ipcSocketResult_t res);
struct ncclIpcSocket { #define IPC_SOCKNAME_LEN 64
struct IpcSocketHandle {
int fd; int fd;
char socketName[NCCL_IPC_SOCKNAME_LEN]; char socketName[IPC_SOCKNAME_LEN];
volatile uint32_t *abortFlag; volatile uint32_t *abortFlag;
}; };
ncclResult_t ncclIpcSocketInit(struct ncclIpcSocket *handle, int rank, uint64_t hash, ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
volatile uint32_t *abortFlag); volatile uint32_t *abortFlag);
ncclResult_t ncclIpcSocketClose(struct ncclIpcSocket *handle); ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle);
ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd); ipcSocketResult_t ipcSocketGetFd(IpcSocketHandle *handle, int *fd);
ncclResult_t ncclIpcSocketRecvFd(struct ncclIpcSocket *handle, int *fd); ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *fd);
ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd, int rank, ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int fd, int rank, uint64_t hash);
uint64_t hash);
#endif /* NCCL_IPCSOCKET_H */ #endif /* TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H */
...@@ -23,15 +23,6 @@ ...@@ -23,15 +23,6 @@
#define MAX_THREADS 1024 #define MAX_THREADS 1024
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define ATOMIC_CONSUMER(chunk) \ #define ATOMIC_CONSUMER(chunk) \
if (counters) { \ if (counters) { \
if (threadIdx.x == 0 && blockIdx.x == 0) { \ if (threadIdx.x == 0 && blockIdx.x == 0) { \
...@@ -1391,7 +1382,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1391,7 +1382,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \ reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \ reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \ reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \ &cfg, \
reinterpret_cast<void *>(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag<x> \ reinterpret_cast<void *>(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag<x> \
: userbuffers_fp16_sum_inplace_gpu_rw_ag<x>), \ : userbuffers_fp16_sum_inplace_gpu_rw_ag<x>), \
...@@ -1416,7 +1407,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1416,7 +1407,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \ reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \ reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11)}; \ reinterpret_cast<void *>(&arg11)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_ag<x>), kernelArgs)); \ &cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_ag<x>), kernelArgs)); \
} }
...@@ -1436,7 +1427,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1436,7 +1427,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \ reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \ reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \ reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs<x>), kernelArgs)); \ &cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs<x>), kernelArgs)); \
} }
...@@ -1458,7 +1449,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1458,7 +1449,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \ reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \ reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11)}; \ reinterpret_cast<void *>(&arg11)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs<x>), kernelArgs)); \ &cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs<x>), kernelArgs)); \
} }
...@@ -1481,7 +1472,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1481,7 +1472,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \ reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \ reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13)}; \ reinterpret_cast<void *>(&arg13)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop<x>), \ &cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop<x>), \
kernelArgs)); \ kernelArgs)); \
} }
...@@ -1506,7 +1497,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1506,7 +1497,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \ reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \ reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \ reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \ &cfg, \
reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8<x, fp8type>), \ reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8<x, fp8type>), \
kernelArgs)); \ kernelArgs)); \
...@@ -1532,7 +1523,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1532,7 +1523,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \ reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \ reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \ reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop<x>), \ &cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop<x>), \
kernelArgs)); \ kernelArgs)); \
} }
...@@ -1562,7 +1553,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1562,7 +1553,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \ reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16), \ reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16), \
reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)}; \ reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \ &cfg, \
reinterpret_cast<void *>( \ reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8<x, fp8type>), \ userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8<x, fp8type>), \
...@@ -1588,7 +1579,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1588,7 +1579,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \ reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \ reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13)}; \ reinterpret_cast<void *>(&arg13)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride<x>), \ &cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride<x>), \
kernelArgs)); \ kernelArgs)); \
} }
...@@ -1614,7 +1605,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1614,7 +1605,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \ reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \ reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15)}; \ reinterpret_cast<void *>(&arg15)}; \
CUDACHECK(cudaLaunchKernelExC( \ NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \ &cfg, \
reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic<x>), \ reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic<x>), \
kernelArgs)); \ kernelArgs)); \
...@@ -1641,7 +1632,7 @@ __global__ void __launch_bounds__(MAX_THREADS) ...@@ -1641,7 +1632,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \ reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \ reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15)}; \ reinterpret_cast<void *>(&arg15)}; \
CUDACHECK( \ NVTE_CHECK_CUDA( \
cudaLaunchKernelExC(&cfg, \ cudaLaunchKernelExC(&cfg, \
reinterpret_cast<void *>( \ reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic<x>), \ userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic<x>), \
...@@ -2206,15 +2197,6 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat ...@@ -2206,15 +2197,6 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
} }
} }
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
// Return TRUE if two ranks share the same NV domain // Return TRUE if two ranks share the same NV domain
#define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize)) #define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize))
...@@ -2259,7 +2241,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds ...@@ -2259,7 +2241,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
if (comm->use_ce) { if (comm->use_ce) {
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr)); // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
-      CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
+      NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
       // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
     }
     SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
@@ -2269,7 +2251,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
     void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
                           reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
                           reinterpret_cast<void *>(&arg5)};
-    CUDACHECK(
+    NVTE_CHECK_CUDA(
         cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
   }
 }
@@ -2291,7 +2273,8 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
   if (comm->use_ce) {
     // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
-    CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
+    NVTE_CHECK_CUDA(
+        cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
     // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
   }
   SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
@@ -2323,7 +2306,7 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
                         reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12),
                         reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
                         reinterpret_cast<void *>(&arg15)};
-  CUDACHECK(
+  NVTE_CHECK_CUDA(
       cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv), kernelArgs));
 }
@@ -2346,7 +2329,8 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
       reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
   if (comm->use_ce) {
     // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
-    CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
+    NVTE_CHECK_CUDA(
+        cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
     // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
   }
   SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
@@ -2379,8 +2363,8 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
                         reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12),
                         reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
                         reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16)};
-  CUDACHECK(cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic),
-                                kernelArgs));
+  NVTE_CHECK_CUDA(cudaLaunchKernelExC(
+      &cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic), kernelArgs));
 }
 void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
@@ -2425,7 +2409,7 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
                         reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
                         reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16),
                         reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)};
-  CUDACHECK(cudaLaunchKernelExC(
+  NVTE_CHECK_CUDA(cudaLaunchKernelExC(
       &cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_multiatomic), kernelArgs));
 }
@@ -2451,7 +2435,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
     if (!signalonly)
       kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
     if (comm->use_ce) {
-      CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
+      NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
     }
   } else {
     kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(
...
@@ -15,39 +15,11 @@
 #include <functional>
 #include <stdexcept>
-#ifdef UB_MPI_BOOTSTRAP
-#include <mpi.h>
-#include <stdexcept>
-#define UB_MPI_CHECK(expr) \
-  do { \
-    const int mpicode = (expr); \
-    if (mpicode != MPI_SUCCESS) { \
-      char mpimsg[MPI_MAX_ERROR_STRING]; \
-      int mpilen; \
-      MPI_Error_string(mpicode, mpimsg, &mpilen); \
-      std::vector<char> errmsg(1024); \
-      snprintf(errmsg.data(), errmsg.size(), "%s:%s in function %s: %s", __FILE__, __LINE__, \
-               __func__, mpimsg); \
-      throw std::runtime_error(errmsg.data()); \
-    } \
-  } while (false)
+#include "common/util/logging.h"
+#ifdef NVTE_UB_WITH_MPI
+#include <mpi.h>
 typedef MPI_Comm ExtComm;
-void ub_alloc_copy_allgather(void **globaldata, void *localdata, size_t localbytes, ExtComm comm) {
-  int myrank, nranks;
-  UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank));
-  UB_MPI_CHECK(MPI_Comm_size(comm, &nranks));
-  *globaldata = malloc(nranks * localbytes);
-  UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE, *globaldata, nranks * localbytes,
-                             MPI_BYTE, comm));
-}
-void ub_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); }
-void ub_free(void *ptr) { free(ptr); }
 #else
 typedef char *ExtComm;
 #endif
@@ -170,14 +142,13 @@ struct communicator {
   volatile int tail;
   // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
-  std::function<void(void **, void *, size_t, ExtComm)> _alloc_copy_allgather;
+  std::function<void(void *, size_t, void *, size_t, ExtComm)> _allgather;
   std::function<void(ExtComm)> _barrier;
-  std::function<void(void *)> _free;
   ExtComm comm_world,
       comm_inter,  // reduction group communicator (subset of the nodes) along GPU rail
       comm_intra;  // full intranode (all ndev GPUS)
-#ifdef UB_MPI_BOOTSTRAP
+#ifdef NVTE_UB_WITH_MPI
   MPI_Request mpihndl[NVTE_MAX_SHARP];
 #endif
@@ -194,20 +165,19 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr
 /* creates communicator, allocates all internal buffers if necessary */
 int create_communicator_grouped2(
     communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
-    int numnodes, std::function<void(void **, void *, size_t, ExtComm)> ext_alloc_copy_allgather,
-    std::function<void(ExtComm)> ext_barrier, std::function<void(void *)> ext_free, int pipegpus,
-    int pipenodes, int tensorgpus, int tensornodes);
+    int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
+    std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
+    int tensornodes);
 int create_communicator_grouped(
     communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
-    int numnodes, std::function<void(void **, void *, size_t, ExtComm)> ext_alloc_copy_allgather,
-    std::function<void(ExtComm)> ext_barrier, std::function<void(void *)> ext_free, int pipegpus,
-    int pipenodes);
+    int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
+    std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes);
-int create_communicator(
-    communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
-    int numnodes, std::function<void(void **, void *, size_t, ExtComm)> ext_alloc_copy_allgather,
-    std::function<void(ExtComm)> ext_barrier, std::function<void(void *)> ext_free);
+int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
+                        int mynode, int numnodes,
+                        std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
+                        std::function<void(ExtComm)> ext_barrier);
 int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
                                      int tensorgpus, int tensornodes);
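The callback signature change above replaces the old allocate-and-return allgather with an in-place variant: the caller now passes a pre-sized global buffer (and its byte count) that the callback fills. As a rough Python analogue of the torch.distributed-backed callback this PR wires in (illustrative only; the real callback lives in the C++/pybind bootstrap layer, and the helper name and 1-D buffer assumption below are mine):

```python
import torch
import torch.distributed as dist

def ub_allgather(global_buf: torch.Tensor, local_buf: torch.Tensor, group=None) -> None:
    """Fill the caller-provided 1-D global_buf with every rank's local_buf, in rank order."""
    world_size = dist.get_world_size(group)
    assert global_buf.numel() == world_size * local_buf.numel()
    chunks = [torch.empty_like(local_buf) for _ in range(world_size)]
    # all_gather() is supported by Gloo, MPI and NCCL alike, unlike all_gather_into_tensor()
    dist.all_gather(chunks, local_buf, group=group)
    global_buf.copy_(torch.cat(chunks))
```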
...
@@ -7,6 +7,9 @@ import io
 import os
 import pickle
 import warnings
+import socket
+import fcntl
+import struct
 from abc import ABC, abstractmethod
 from typing import Dict, Generator, List, Optional, Tuple, Union
 from contextlib import contextmanager
@@ -79,19 +82,109 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
 def initialize_ub(
     shape: list,
-    tp_group: dist_group_type,
+    tp_size: int,
     use_fp8: bool = False,
     dtype: torch.dtype = torch.bfloat16,
     ub_cfgs: Optional[dict] = None,
+    bootstrap_backend: Union[str, torch.distributed.Backend] = None,
 ) -> None:
     """Initialize communicators for TP comm overlap using userbuffers."""
+    if not tex.device_supports_multicast():
+        assert bool(os.getenv("UB_SKIPMC", "0")), (
+            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
+            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
+        )
     global _ub_communicators
     assert _ub_communicators is None, "UB communicators are already initialized."
     _ub_communicators = {}
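For orientation, a minimal sketch of how a training script calls the updated function: `tp_size` and the optional `bootstrap_backend` replace the old `tp_group` argument. The import path and the sample buffer shape below are illustrative assumptions, not part of this diff.

```python
import os
import torch
import transformer_engine.pytorch as te

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

te.module.base.initialize_ub(
    shape=[4096, 12288],       # [sequence * micro-batch, hidden] sample buffer shape
    tp_size=8,                 # tensor-parallel group size (was: tp_group)
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="gloo",  # optional; if omitted, Gloo is preferred, then MPI, then NCCL
)
```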
-    rank_id = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-    tp_id = torch.distributed.get_rank(tp_group)
-    tp_size = torch.distributed.get_world_size(tp_group)
+    if tex.ubuf_built_with_mpi():
+        # Userbuffers will ignore all these values when it is built with MPI, so these are just
+        # placeholders based on an assumption that tp_size covers all devices in a physical node.
+        assert torch.distributed.is_mpi_available()
+        mpi_group = torch.distributed.new_group(backend="mpi")
+        world_rank = torch.distributed.get_rank(mpi_group)
+        world_size = torch.distributed.get_world_size(mpi_group)
+        local_rank = world_rank % tp_size
+        local_size = tp_size
+        node_id = world_rank // tp_size
+        num_nodes = world_size // tp_size
+        ub_callbacks = tex.UbufBootstrapCallbacks()
+    else:
+        assert (
+            torch.distributed.is_initialized()
+        ), "torch.distributed must be initialized before Userbuffers"
+        if bootstrap_backend is None:
+            bootstrap_backend = "nccl"
+            if torch.distributed.is_gloo_available():
+                bootstrap_backend = "gloo"
+            elif torch.distributed.is_mpi_available():
+                bootstrap_backend = "mpi"
+        else:
+            assert bootstrap_backend in ["gloo", "mpi", "nccl"]
+        world_group = torch.distributed.new_group(backend=bootstrap_backend)
+        world_rank = torch.distributed.get_rank(world_group)
+        world_size = torch.distributed.get_world_size(world_group)
+        if world_rank == 0:
+            print(
+                f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n',
+                end="",
+                flush=True,
+            )
+        # Construct an intra-node communicator based on global ranks that share the same hostname
+        # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host
+        #       address on that interface instead of the hostname. This can help avoid issues when
+        #       different hosts have the same hostname on Kubernetes clusters.
+        hostname = socket.gethostname()
+        ifname = os.getenv(
+            "NVTE_UB_SOCKET_IFNAME",
+            os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
+        )
+        if ifname is not None:
+            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            try:
+                hostname = socket.inet_ntoa(
+                    fcntl.ioctl(
+                        s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
+                    )[20:24]
+                )
+            except OSError as err:
+                raise OSError(f"Invalid network interface: {ifname}") from err
+        hostnames = [None for _ in range(world_size)]
+        torch.distributed.all_gather_object(hostnames, hostname, world_group)
+        intra_node_ranks = []
+        for i, host in enumerate(hostnames):
+            if host == hostname:
+                intra_node_ranks.append(i)
+        if len(intra_node_ranks) == world_size:
+            intra_node_group = world_group
+            local_rank = world_rank
+            local_size = world_size
+            intra_node_ranks = list(range(world_size))
+        else:
+            intra_node_group = torch.distributed.new_group(
+                backend=bootstrap_backend, ranks=intra_node_ranks
+            )
+            local_rank = torch.distributed.get_rank(intra_node_group)
+            local_size = torch.distributed.get_world_size(intra_node_group)
+        node_id = world_rank // local_size
+        num_nodes = world_size // local_size
+        if local_rank == 0:
+            print(
+                f"!!! [NVTE] Number of physical nodes: {num_nodes}\n"
+                + f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n",
+                end="",
+                flush=True,
+            )
+        ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group)
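To illustrate the grouping above, here is the same hostname-based logic applied to a hypothetical 2-node x 4-GPU job with a contiguous rank layout (all values are made up for the example; with a contiguous layout, the rank within the intra-node group equals the index into `intra_node_ranks`):

```python
hostnames = ["node0"] * 4 + ["node1"] * 4          # as gathered via all_gather_object()
world_rank = 5                                     # example process on the second node
hostname = hostnames[world_rank]

intra_node_ranks = [i for i, host in enumerate(hostnames) if host == hostname]
local_size = len(intra_node_ranks)                 # 4
local_rank = intra_node_ranks.index(world_rank)    # 1
node_id = world_rank // local_size                 # 1
num_nodes = len(hostnames) // local_size           # 2

assert intra_node_ranks == [4, 5, 6, 7]
```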
     # Increase the workspace by the number of maximum concurrent streams
     global _cublas_workspace
@@ -127,6 +220,23 @@ def initialize_ub(
                 return method
         raise KeyError(f"Given layer name {name} does not exist.")
+    def get_default_config(name):
+        method = get_method(name)
+        is_reduce_scatter = name in layers_reduce_scatter_overlap
+        default_cfg = {
+            "method": method,
+            "is_reduce_scatter": is_reduce_scatter,
+            "num_sm": 1 if method == "ring_exchange" else 16,
+            "cga_size": 1 if method == "ring_exchange" else 2,
+            "set_sm_margin": False,
+            "num_splits": 4 if method == "pipeline" else tp_size,
+            "aggregate": False,
+            "atomic_gemm": False,
+            "use_ce": True,
+            "fp8_buf": name in layers_all_gather_overlap,
+        }
+        return default_cfg
     def add_ub(
         name: str,
         method: str,
@@ -180,53 +290,43 @@ def initialize_ub(
         if method == "ring_exchange":
             ub_obj = tex.UbufP2PCommOverlap(
                 sample_buffer,  # Sample userbuffer
-                rank_id,  # Rank id
+                world_rank,  # World rank
                 world_size,  # World size
-                tp_id,  # TP id
-                tp_size,  # TP size
+                local_rank,  # Rank within the node
+                local_size,  # Number of ranks/GPUs per node
+                node_id,  # Node ID
+                num_nodes,  # Number of nodes
+                tp_size,  # Tensor-parallel group size (may be different than local_size)
                 num_sm,  # Number of communication SMs
                 cga_size,  # CGA cluster size
                 set_sm_margin,  # Set SM margin
                 aggregate,  # Aggregate 2X GEMM chunks
                 _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
-                is_reduce_scatter,  # overlap with reduce scatter
-                atomic_gemm,  # use a single GEMM with atomic-counters
-                use_ce,  # use copy engine for P2P communications
-                torch.Tensor(),  # empty tensor to pass to counters
+                is_reduce_scatter,  # Overlap with reduce scatter
+                atomic_gemm,  # Use a single GEMM with atomic-counters
+                use_ce,  # Use copy engine for P2P communications
+                ub_callbacks,
             )
         else:
             ub_obj = tex.UbufCommOverlap(
                 sample_buffer,  # Sample userbuffer
-                rank_id,  # Rank id
+                world_rank,  # World rank
                 world_size,  # World size
-                tp_id,  # TP id
-                tp_size,  # TP size
+                local_rank,  # Rank within the node
+                local_size,  # Number of ranks/GPUs per node
+                node_id,  # Node ID
+                num_nodes,  # Number of nodes
+                tp_size,  # Tensor-parallel group size (may be different than local_size)
                 num_sm,  # Number of communication SMs
                 cga_size,  # CGA cluster size
                 num_splits,  # Number of communication splits
                 set_sm_margin,  # Set SM margin
                 _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
-                atomic_gemm,  # use a single GEMM with atomic-counters
-                torch.Tensor(),  # empty tensor to pass to counters
+                atomic_gemm,  # Use a single GEMM with atomic-counters
+                ub_callbacks,
             )
         _ub_communicators[name] = ub_obj
-    def alloc_copy_allgather_callback(local_data: torch.Tensor, group: str) -> torch.Tensor:
-        pg = None if group == "world" else tp_group
-        global_size = local_data.numel() * torch.distributed.get_world_size(pg)
-        global_data = torch.zeros(global_size, dtype=local_data.dtype, device="cuda")
-        torch.distributed.all_gather_into_tensor(global_data, local_data.cuda(), group=pg)
-        return global_data.cpu()
-    def barrier_callback(group: str) -> None:
-        pg = None if group == "world" else tp_group
-        torch.distributed.barrier(group=pg)
-    def free_callback(data: torch.Tensor) -> None:
-        data.data = torch.Tensor()
-    tex.set_ubuf_bootstrap_callbacks(alloc_copy_allgather_callback, barrier_callback, free_callback)
     if ub_cfgs is not None:
         for name in dgrad_reduce_scatter_overlap:
             if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
@@ -238,48 +338,18 @@ def initialize_ub(
                 methods["pipeline"].append(name)
     for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
+        ub_cfg = get_default_config(name)
         if ub_cfgs is not None and name in ub_cfgs:
-            ub_cfg = ub_cfgs[name]
-            method = ub_cfg.get("method", get_method(name))
-            num_sm = ub_cfg.get("num_sm", 1 if method == "ring_exchange" else 16)
-            cga_size = ub_cfg.get("cga_size", 1 if method == "ring_exchange" else 2)
-            num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
-            set_sm_margin = ub_cfg.get("set_sm_margin", 0)
-            aggregate = ub_cfg.get("aggregate", 0)
-            atomic_gemm = ub_cfg.get("atomic_gemm", 0)
-            use_ce = ub_cfg.get("use_ce", True)
-            is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0
             # Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter
             fp8_buf = (name in layers_all_gather_overlap) or (
-                ub_cfg.get("fp8_buf", False) and name in methods["pipeline"]
+                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
             )
-            add_ub(
-                name,
-                method,
-                is_reduce_scatter,
-                num_sm,
-                cga_size,
-                set_sm_margin,
-                num_splits,
-                aggregate,
-                atomic_gemm,
-                use_ce,
-                fp8_buf,
-            )
-        else:
-            method = get_method(name)
-            add_ub(
-                name,
-                method=method,
-                is_reduce_scatter=1 if name in layers_reduce_scatter_overlap else 0,
-                num_splits=4 if method == "pipeline" else 0,
-                fp8_buf=name in layers_all_gather_overlap,
-            )
+            ub_cfg.update(ub_cfgs[name])
+            ub_cfg["fp8_buf"] = fp8_buf
+        add_ub(name, **ub_cfg)
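With the merge above, a user override only needs to name the fields it changes; everything else comes from `get_default_config()`. A usage sketch (the layer key "fc1_fprop" is just an example):

```python
ub_cfgs = {
    "fc1_fprop": {
        "method": "ring_exchange",  # override the default method for this layer
        "num_sm": 4,                # any field not listed keeps its default value
    },
}
# Inside initialize_ub() this is applied roughly as:
#   ub_cfg = get_default_config("fc1_fprop")
#   ub_cfg.update(ub_cfgs["fc1_fprop"])   # user-specified values win
#   add_ub("fc1_fprop", **ub_cfg)
```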
 def get_ub(name: str):
     """Get userbuffer communicator corresponding to give key."""
-    global _ub_communicators
     assert _ub_communicators is not None, "UB manager is not initialized."
     assert name in _ub_communicators, f"UB for {name} is not registered."
     return _ub_communicators[name]
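Lookup is unchanged: after `initialize_ub()` has run, layers retrieve their registered overlap object by name (the key below is an assumed example):

```python
ub_obj = get_ub("fc1_fprop")  # raises AssertionError if the key was never registered
```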
...