Unverified commit 5ee98175 authored by Alp Dener, committed by GitHub

[PyTorch] Fixing hang in `initialize_ub()` for multi-node runs after PR901 removal of MPI-dependence (#986)

* Re-implementing PR901 (removing MPI-dependence in Userbuffers) with multi-node fixes

* passing data-parallel rank/size info from torch.distributed to userbuffers
Signed-off-by: Alp Dener <adener@nvidia.com>

* multi-node example working with UB_SKIPMC=1 but not with multicast
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed multi-node hang in initialize_ub(), updated comm+GEMM overlap example to support multi-node mixed tensor/data parallelism, added README
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed the use case where Userbuffers is asked to allocate the TP overlap buffer with UB_SKIPMC=1
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected example problem to set device by local ordinal instead of global process rank
Signed-off-by: Alp Dener <adener@nvidia.com>

* double-free fix in userbuffers destructor
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unnecessary and incorrect torch.cuda.set_device(...)
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected inter-node ranks logic
Signed-off-by: Alp Dener <adener@nvidia.com>

* generalized node ID logic in initialize_ub to handle arbitrary world rank layouts within a node
Signed-off-by: Alp Dener <adener@nvidia.com>

* added single-node comm+GEMM overlap unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

* LayerNormMLP example confirmed working with 2 nodes on Eos
Signed-off-by: Alp Dener <adener@nvidia.com>

* unit test cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected DP group ranks logic in LNMLP comm+GEMM overlap example
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected enums in unit test
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect Ubuf object init signature
Signed-off-by: Alp Dener <adener@nvidia.com>

* switched the default backend for Userbuffers bootstrapping to Gloo with MPI and NCCL fallbacks, and added an `initialize_ub` option to manually select the backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed all comm+GEMM overlap unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected all_gather use for Gloo backend
Signed-off-by: Alp Dener <adener@nvidia.com>

* changed userbuffers allgather callback to always use all_gather() instead of all_gather_into_tensor()
Signed-off-by: Alp Dener <adener@nvidia.com>

* restored and verified old MPI-based bootstrapping via NVTE_UB_WITH_MPI=1 option at compile time
Signed-off-by: Alp Dener <adener@nvidia.com>

* disabled scoped GIL release for comm+GEMM overlap algorithms
Signed-off-by: Alp Dener <adener@nvidia.com>

* avoid dist.init_device_mesh in comm+GEMM overlap example to support older PyTorch versions
Signed-off-by: Alp Dener <adener@nvidia.com>

* applied RS overlap FP8 fix from PR1004
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed segfault in Userbuffers destructor
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected comm+GEMM overlap unit test arguments
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed unit test run command for when Userbuffers is compiled with MPI
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Refactored torch.distributed collectives into pure C++ callbacks
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 71124c31
@@ -77,14 +77,14 @@ def setup_pytorch_extension(
     # Libraries
     library_dirs = []
     libraries = []
-    if os.getenv("UB_MPI_BOOTSTRAP"):
+    if os.getenv("NVTE_UB_WITH_MPI"):
         assert (
             os.getenv("MPI_HOME") is not None
-        ), "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
+        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
         mpi_home = Path(os.getenv("MPI_HOME"))
         include_dirs.append(mpi_home / "include")
-        cxx_flags.append("-DUB_MPI_BOOTSTRAP")
-        nvcc_flags.append("-DUB_MPI_BOOTSTRAP")
+        cxx_flags.append("-DNVTE_UB_WITH_MPI")
+        nvcc_flags.append("-DNVTE_UB_WITH_MPI")
         library_dirs.append(mpi_home / "lib")
         libraries.append("mpi")
......
# Overlapping Communication with GEMM in TransformerEngine Modules
## Requirements
- Tensor-parallel GPUs must be on a single node and connected via NVLink/NVSwitch.
- `CUDA_DEVICE_MAX_CONNECTIONS=1` must be set in the environment.
- For best performance, point-to-point communication via _CUDA Multicast_ requires CUDA Toolkit 12.0+
  and CUDA driver 535+ on devices with compute capability 9.0 or newer.
- Devices older than compute capability 9.0 require `UB_SKIPMC=1` in the environment in order to fall
  back on a less performant implementation based on CUDA Inter-Process Communication (IPC) handles.
## Examples
### Single node, tensor-parallel LayerNormMLP:
Forward and backward passes with layer weights distributed over all GPUs in a single node.
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py
# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
# !!! [UB] Create UbufP2PCommOverlap Communicator
# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
# MC initialized succesfully, window size = 549755813888
# !!! [UBP2P] Register UBuf 1
# !!! [UBP2P] Register UBuf 2
# !!! [UBP2P] Register UBuf 3
# !!! [UBP2P] Register UBuf 4
# !!! [UB] Register UBuf 5
# !!! [UBP2P] Register UBuf 6
# !!! [UB] Register UBuf 7
# !!! [UB] Register UBuf 8
# !!! [UBP2P] Register UBuf 9
# !!! [UB] Register UBuf 10
# [rank0:node0] Iter 1
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# ... (same output repeats for Iter 2 through Iter 5)
```
### Single node, mixed data- and tensor-parallel LayerNormMLP:
Uses `torch.nn.parallel.DistributedDataParallel` to replicate the model across 2 tensor-parallel
groups on a single node.
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py --num-replicas 2
# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
# [rank4:node1] |-- Created tensor-parallel group: [4, 5, 6, 7]
# [rank0:node0] |-- Created data-parallel group: [0, 4]
# [rank3:node1] |-- Created data-parallel group: [3, 7]
# [rank1:node1] |-- Created data-parallel group: [1, 5]
# [rank2:node0] |-- Created data-parallel group: [2, 6]
# !!! [UB] Create UbufP2PCommOverlap Communicator
# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
# MC initialized succesfully, window size = 549755813888
# !!! [UBP2P] Register UBuf 1
# !!! [UBP2P] Register UBuf 2
# !!! [UBP2P] Register UBuf 3
# !!! [UBP2P] Register UBuf 4
# !!! [UB] Register UBuf 5
# !!! [UBP2P] Register UBuf 6
# !!! [UB] Register UBuf 7
# !!! [UB] Register UBuf 8
# !!! [UBP2P] Register UBuf 9
# !!! [UB] Register UBuf 10
# [rank4:node1] Iter 1
# [rank0:node0] Iter 1
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# ... (same output repeats on both replicas for Iter 2 through Iter 5)
```
**NOTE:** To run with FP8 compute on supported hardware, add the `--fp8` flag to the commands
shown above.
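The TP/DP group decomposition shown in the sample output above can be reproduced with plain Python. A hedged sketch (the helper name is hypothetical, not part of the example script): the world ranks form a 2D mesh whose rows are tensor-parallel groups and whose columns are data-parallel groups.

```python
# Hypothetical helper: decompose world ranks into TP rows and DP columns.
def make_tp_dp_groups(world_size, num_replicas):
    assert world_size % num_replicas == 0
    tp_size = world_size // num_replicas
    # Rows: contiguous ranks form each tensor-parallel group.
    tp_groups = [list(range(r * tp_size, (r + 1) * tp_size)) for r in range(num_replicas)]
    # Columns: ranks strided by tp_size form each data-parallel group.
    dp_groups = [list(range(c, world_size, tp_size)) for c in range(tp_size)]
    return tp_groups, dp_groups

tp, dp = make_tp_dp_groups(8, 2)
print(tp)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

This matches the groups printed in the sample output (`[0, 1, 2, 3]` / `[4, 5, 6, 7]` for TP, `[0, 4]` etc. for DP).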
......@@ -6,17 +6,22 @@
import os
import sys
import subprocess
import socket
import argparse
import warnings
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
-def parse_args(argv=None, namespace=None):
+def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
)
......@@ -47,63 +52,182 @@ def parse_args(argv=None, namespace=None):
default=False,
help="Disable the comm+GEMM overlap.",
)
-    parser.add_argument("-v", "--verbose", action="store_true", default=False)
-    return parser.parse_args(argv, namespace)
+    parser.add_argument(
+        "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas."
+    )
+    parser.add_argument(
+        "--tcp-init",
+        action="store_true",
+        default=False,
+        help="Initialize torch.distributed with TcpStore.",
+    )
+    parser.add_argument(
+        "--bind-to-device",
+        action="store_true",
+        default=False,
+        help="Initialize torch.distributed with `device_id` to bind each rank to a single device.",
+    )
+    parser.add_argument(
+        "--bootstrap-backend",
+        type=str.lower,
+        default="nccl",
+        choices=["gloo", "mpi", "nccl"],
+        help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        default=False,
+        help="Print out from every rank instead of just the root rank of relevant process groups.",
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        default=False,
+        help="Print out additional debug information.",
+    )
+    args = parser.parse_args(argv, namespace)
+    return args
-def train(opts):
-    WORLD_RANK = int(os.getenv("RANK"))
-    WORLD_SIZE = int(os.getenv("WORLD_SIZE"))
+def _train(opts):
+    if "OMPI_COMM_WORLD_SIZE" in os.environ:
+        # Execution with `mpirun -np N`
+        WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
+        WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
+        LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
+        LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
+        opts.tcp_init = True
+        opts.bind_to_device = True
+        opts.bootstrap_backend = "mpi"
+    elif "TORCHELASTIC_RUN_ID" in os.environ:
+        WORLD_RANK = int(os.getenv("RANK", "0"))
+        WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+        LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
+        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
+    else:
+        raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
+
+    NUM_NODES = WORLD_SIZE // LOCAL_SIZE
def dist_print(msg, group=None, end="\n", debug=False):
if debug and not opts.debug:
return
group = dist.new_group() if group is None else group
group_rank = dist.get_rank(group)
group_size = dist.get_world_size(group)
all_ranks = dist.get_process_group_ranks(group)
ranks_skip = all_ranks[1] - all_ranks[0] > 1
group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size
if group_rank == 0 or opts.verbose:
print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True)
dist.barrier(group)
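The `node` index printed in the log prefixes comes from the heuristic in `dist_print` above: strided rank lists (data-parallel groups like `[3, 7]`) and contiguous rank lists (tensor-parallel groups like `[4, 5, 6, 7]`) need different arithmetic. Restated as a standalone sketch (the function name is hypothetical):

```python
# Hypothetical restatement of the group-ID heuristic in dist_print: if the
# group's ranks are strided (difference > 1), the group index is rank % size;
# if they are contiguous, it is rank // size.
def group_id(world_rank, all_ranks):
    size = len(all_ranks)
    strided = all_ranks[1] - all_ranks[0] > 1
    return world_rank % size if strided else world_rank // size

print(group_id(4, [4, 5, 6, 7]))  # 1  (second contiguous TP group)
print(group_id(3, [3, 7]))        # 1  (3 % 2, matching "[rank3:node1]")
```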
# Initialize torch.distributed global process group and get DP/TP groups
torch.cuda.set_device(LOCAL_RANK)
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
if opts.tcp_init or NUM_NODES > 1:
if NUM_NODES > 1:
assert (
"MASTER_ADDR" in os.environ
), "Multi-node run requires MASTER_ADDR to be set in the environment."
MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
MASTER_PORT = os.getenv("MASTER_PORT", "1234")
dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
if opts.bind_to_device or opts.bootstrap_backend == "nccl":
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs", nccl_world)
# Figure out process groups for tensor- and data-parallelism (if any)
if NUM_NODES > 1:
# Create a list of world ranks on this node
hostnames = [None for _ in range(WORLD_SIZE)]
hostname = socket.gethostname()
dist.all_gather_object(hostnames, hostname)
node_ranks = []
for i, host in enumerate(hostnames):
if host == hostname:
node_ranks.append(i)
if opts.num_replicas > 1:
# Split node ranks into multiple replicas
assert len(node_ranks) % opts.num_replicas == 0
tp_size = len(node_ranks) // opts.num_replicas
found_replica = False
for replica in range(opts.num_replicas):
start = replica * tp_size
end = start + tp_size
tp_ranks = node_ranks[start:end]
if WORLD_RANK in tp_ranks:
found_replica = True
break
assert found_replica
else:
# The entire node is the tensor-parallel group
tp_ranks = node_ranks
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
tp_size = dist.get_world_size(tp_group)
tp_rank = dist.get_rank(tp_group)
# Data-parallelism across TP groups
dp_start = tp_rank
dp_end = dp_start + WORLD_SIZE
dp_ranks = list(range(dp_start, dp_end, tp_size))
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
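The hostname gathering above is what generalizes the node ID logic to arbitrary rank-to-node layouts (one of the fixes in this PR). A standalone sketch of the same grouping idea, without `torch.distributed` (the helper name is hypothetical):

```python
# Hypothetical helper: group world ranks into per-node lists by hostname,
# mirroring the all_gather_object approach in the example script.
def group_ranks_by_node(hostnames):
    """hostnames[i] is the hostname reported by world rank i."""
    node_ranks = {}
    for rank, host in enumerate(hostnames):
        node_ranks.setdefault(host, []).append(rank)
    # Node IDs follow first appearance, so arbitrary layouts
    # (e.g. round-robin rank placement) are handled correctly.
    return list(node_ranks.values())

# Round-robin layout: ranks 0,2 on node "a" and 1,3 on node "b".
print(group_ranks_by_node(["a", "b", "a", "b"]))  # [[0, 2], [1, 3]]
```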
+    else:
+        if opts.num_replicas > 1:
+            # Mixed data- and tensor-parallelism on a single node
+            # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
+            all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
+            mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas))
+            node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist()
-    def dist_print(msg, end="\n", all_ranks=False):
-        if WORLD_RANK == 0 or all_ranks:
-            print(f"[RANK-{WORLD_RANK}] {msg}", end=end)
+            tp_ranks = mesh2d[node_idx[0], :].tolist()
+            tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
-    # Seed RNG
-    torch.cuda.set_device(WORLD_RANK)
-    torch.manual_seed(opts.seed + WORLD_RANK)
-    torch.cuda.manual_seed(opts.seed + WORLD_RANK)
+            dp_ranks = mesh2d[:, node_idx[1]].tolist()
+            dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
+        else:
+            dp_group = None
+            tp_group = nccl_world
-    # Initialize torch.distributed global process group and get TP group
-    dist.init_process_group(
-        backend="nccl",
-        rank=WORLD_RANK,
-        world_size=WORLD_SIZE,
-        device_id=torch.device(f"cuda:{WORLD_RANK}"),
-    )
-    tp_group = dist.new_group(backend="nccl")
-    tp_size = dist.get_world_size(tp_group)
+    tp_rank = dist.get_rank(tp_group)
+    tp_size = dist.get_world_size(tp_group)
+    dist_print(
+        f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}",
+        group=tp_group,
+    )
+    if dp_group is not None:
+        dist_print(
+            f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}",
+            group=dp_group,
+        )
# Initialize Userbuffers
ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad
"method": "ring_exchange",
"num_splits": 8,
"num_sm": 1,
"set_sm_margin": False,
}
rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop
"method": "ring_exchange",
"num_splits": 4,
"num_sm": 1,
"set_sm_margin": True,
}
hidden_size = opts.num_heads * opts.head_dim
batched_size = opts.seq_length * opts.batch_size
if not opts.no_comm_overlap:
-        te.initialize_ub(
+        te.module.base.initialize_ub(
[batched_size, hidden_size],
tp_group,
tp_size,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
ub_cfgs={
"fc1_fprop": ag_cfg,
"fc1_dgrad": rs_cfg,
"fc2_fprop": rs_cfg,
"fc2_dgrad": ag_cfg,
},
bootstrap_backend=opts.bootstrap_backend,
)
# Initialize the fused LayerNorm + Multi-layer Perceptron module
torch.manual_seed(opts.seed + tp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank)
model = te.LayerNormMLP(
hidden_size,
opts.mlp_expansion_factor * hidden_size,
......@@ -114,11 +238,14 @@ def train(opts):
set_parallel_mode=True,
sequence_parallel=True, # this is required for comm+GEMM overlap
seq_length=opts.seq_length,
micro_batch_size=opts.batch_size,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_overlap_rs=not opts.no_comm_overlap,
ub_overlap_ag=not opts.no_comm_overlap,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_bulk_dgrad=False,
ub_bulk_wgrad=not opts.no_comm_overlap,
)
if dp_group is not None:
model = DistributedDataParallel(model, process_group=dp_group)
# Initialize optimizer with model parameters
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
......@@ -128,10 +255,11 @@ def train(opts):
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations
+    dist_print("Starting training iterations...", nccl_world)
     for i in range(opts.num_iters):
-        dist_print(f"Iter {i+1}", all_ranks=opts.verbose)
+        dist_print(f" Iter {i+1}", tp_group, debug=True)
-        dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
+        dist_print(" |-- Generate random input batch", tp_group, debug=True)
x = torch.rand(
(opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16,
......@@ -139,30 +267,29 @@ def train(opts):
requires_grad=True,
)
-        dist_print("|-- Forward pass", all_ranks=opts.verbose)
-        with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
+        dist_print(" |-- Forward pass", tp_group, debug=True)
+        with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
             y = model(x)
-        dist_print("|-- Compute loss", all_ranks=opts.verbose)
+        dist_print(" |-- Compute loss", tp_group, debug=True)
         loss = y.flatten().sum()
-        dist_print("|-- Backward pass", all_ranks=opts.verbose)
+        dist_print(" |-- Backward pass", tp_group, debug=True)
         loss.backward()
-        dist_print("|-- Optimizer step", all_ranks=opts.verbose)
+        dist_print(" |-- Optimizer step", tp_group, debug=True)
optim.step()
-    te.destroy_ub()
+    torch.cuda.synchronize()
+    dist_print("Finished training!")
+    te.module.base.destroy_ub()
+    dist_print("Destroying all process groups...", debug=True)
     dist.destroy_process_group()
+    if opts.debug and WORLD_RANK == 0:
+        print("Exiting...\n", end="", flush=True)
+    return 0
if __name__ == "__main__":
-    if "TORCHELASTIC_RUN_ID" in os.environ.keys():
-        args = parse_args()
-        train(args)
-    else:
-        subprocess.run(
-            ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
-            env=os.environ,
-            check=True,
-        )
-        os._exit(0)
+    sys.exit(_train(_parse_args()))
This diff is collapsed.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
RNG_SEED: int = 1234
SEQ_LENGTH: int = 2024
BATCH_SIZE: int = 2
NUM_HEADS: int = 64
HEAD_DIM: int = 128
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(torch.cuda.device_count(), 4)
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
if tex.ubuf_built_with_mpi():
LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python"]
# Fall back on CUDA IPC if the platform does not support CUDA multicast
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"
# Force GPU kernels to launch in the order they're executed by the host CPU
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
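The launcher selection above boils down to a small decision: `torchrun` by default, `mpirun` when Userbuffers was compiled with MPI bootstrapping. A hedged sketch (the helper name is hypothetical, mirroring the two commands in this file):

```python
# Hypothetical helper mirroring the LAUNCH_CMD logic in the test file.
def launch_cmd(num_procs, built_with_mpi):
    if built_with_mpi:
        # MPI-bootstrapped Userbuffers must be launched under mpirun.
        return ["mpirun", "-np", str(num_procs), "--oversubscribe", "--quiet", "python"]
    return ["torchrun", f"--nproc_per_node={num_procs}"]

print(launch_cmd(4, False))  # ['torchrun', '--nproc_per_node=4']
```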
@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.")
@pytest.mark.parametrize(
"fp8,p2p,comm_type,aggregate,atomic,bulk",
[
# FP8, P2P, Type, Aggregate, Atomic, Bulk
(False, True, "AG", False, False, False),
(False, True, "AG", True, False, False),
(True, True, "AG", False, False, False),
(True, True, "AG", True, False, False),
(False, False, "RS", False, False, False),
(False, True, "RS", False, False, False),
(True, False, "RS", False, False, False),
(True, True, "RS", False, False, False),
(True, False, "RS", False, True, False),
(True, True, "RS", False, True, False),
(False, False, "AG", False, False, True),
(False, False, "RS", False, False, True),
],
ids=[
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ",
" SPLIT GEMM -> RS | BF16 | PIPELINE ",
" SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ",
" SPLIT GEMM -> RS | FP8 | PIPELINE ",
" SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ",
" ATOMIC GEMM -> RS | FP8 | PIPELINE ",
" ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ",
" BULK AG & GEMM | BF16 | PIPELINE ",
" BULK RS & GEMM | BF16 | PIPELINE ",
],
)
def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
"""
Test comm+GEMM overlap algorithms with direct calls to
te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm
"""
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = (
LAUNCH_CMD
+ [str(test_path)]
+ [
"--check-numerics",
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--comm-type={comm_type}",
]
)
if bulk:
test_cmd.append("--bulk-overlap")
else:
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if p2p:
test_cmd.append("--p2p")
if aggregate:
test_cmd.append("--aggregate")
if atomic:
if torch.cuda.get_device_properties(0).major < 9:
pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.")
test_cmd.append("--atomic")
output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False)
assert "NUMERICAL CHECK PASSED" in str(output)
......@@ -206,11 +206,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
  // Communication functions to initialize Userbuffers communicators
  // Note: Callbacks are not called, so safe to release GIL.
-  m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks,
-        py::call_guard<py::gil_scoped_release>());
+  m.def("device_supports_multicast", &ubuf::device_supports_multicast,
+        py::call_guard<py::gil_scoped_release>());
+  m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi,
+        py::call_guard<py::gil_scoped_release>());
+  py::class_<ubuf::UbufBootstrapCallbacks>(m, "UbufBootstrapCallbacks")
+      .def(py::init<>(), py::call_guard<py::gil_scoped_release>())
+      .def(py::init<c10d::ProcessGroup *, c10d::ProcessGroup *>(),
+           py::call_guard<py::gil_scoped_release>());
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
......@@ -225,8 +231,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
-      .def(py::init<torch::Tensor&, int, int, int, int, int, int, int, bool, int, bool,
-                    torch::Tensor>())
+      .def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, int, bool, int,
+                    bool, ubuf::UbufBootstrapCallbacks &>(),
+           py::call_guard<py::gil_scoped_release>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs,
......@@ -250,8 +257,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// communicator with Python functions (e.g. PyTorch distributed
// communication)
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
-      .def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool, bool,
-                    torch::Tensor>())
+      .def(py::init<torch::Tensor &, int, int, int, int, int, int, int, int, int, bool, bool, int,
+                    bool, bool, bool, ubuf::UbufBootstrapCallbacks &>(),
+           py::call_guard<py::gil_scoped_release>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag,
py::call_guard<py::gil_scoped_release>())
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs,
......
......@@ -7,66 +7,82 @@
#include "ipcsocket.h"
#include <errno.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#define WARN(...) \
{}
#define TRACE(...) \
{}
#define SYSCHECK(...) \
{}
#define EQCHECK(...) \
{}
#define IPC_MAX_MSGLEN 4096
-// Enable Linux abstract socket naming
-#define USE_ABSTRACT_SOCKET
-#define NCCL_IPC_SOCKNAME_STR "/tmp/nccl-socket-%d-%lx"
+void ipc_warn(const char *format, ...) {
+  char buffer[IPC_MAX_MSGLEN];
+  va_list args;
+  va_start(args, format);
+  vsnprintf(buffer, IPC_MAX_MSGLEN - 1, format, args);
+  snprintf(buffer + strlen(buffer), IPC_MAX_MSGLEN - strlen(buffer) - 1, " : %s (%d)\n",
+           strerror(errno), errno);
+  fflush(stdout);
+  fputs(buffer, stderr);
+  fflush(NULL);
+  va_end(args);
+}
+
+static const char *ipcSocketResultStrings[static_cast<int>(ipcSocketNumResults)] = {
+    "Success", "Unhandled CUDA error", "System error", "Internal error",
+    "Invalid argument", "Invalid usage", "Remote error", "In progress",
+};
+
+const char *ipcSocketGetErrorString(ipcSocketResult_t res) {
+  return ipcSocketResultStrings[static_cast<int>(res)];
+}
+
+#define USE_ABSTRACT_SOCKET  // Enable Linux abstract socket naming
+#define IPC_SOCKNAME_STR "/tmp/ub-ipc-socket-%d-%lx"
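A quick Python sketch of the socket-name template above, including the Linux abstract-socket convention where a leading NUL byte keeps the name out of the filesystem, so no `unlink()` cleanup is needed (the helper name is hypothetical):

```python
# Hypothetical mirror of the C format string "/tmp/ub-ipc-socket-%d-%lx".
def ipc_sock_name(rank, hash_val, abstract=True):
    name = "/tmp/ub-ipc-socket-%d-%lx" % (rank, hash_val)
    # Linux abstract socket trick: a leading NUL byte means the socket
    # never appears in the filesystem.
    return b"\0" + name.encode() if abstract else name.encode()

print(ipc_sock_name(3, 0xDEADBEEF))  # b'\x00/tmp/ub-ipc-socket-3-deadbeef'
```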
/*
* Create a Unix Domain Socket
*/
-ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash,
-                               volatile uint32_t *abortFlag) {
+ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
+                                volatile uint32_t *abortFlag) {
   int fd = -1;
   struct sockaddr_un cliaddr;
-  char temp[NCCL_IPC_SOCKNAME_LEN] = "";
+  char temp[IPC_SOCKNAME_LEN] = "";
   if (handle == NULL) {
-    return ncclInternalError;
+    return ipcSocketInternalError;
}
handle->fd = -1;
handle->socketName[0] = '\0';
   if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
-    WARN("UDS: Socket creation error : %s (%d)", strerror(errno), errno);
-    return ncclSystemError;
+    ipc_warn("UDS: Socket creation error");
+    return ipcSocketSystemError;
   }
   bzero(&cliaddr, sizeof(cliaddr));
   cliaddr.sun_family = AF_UNIX;
   // Create unique name for the socket.
-  size_t len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash);
+  size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
   if (len > (sizeof(cliaddr.sun_path) - 1)) {
-    WARN("UDS: Cannot bind provided name to socket. Name too large");
-    return ncclInternalError;
+    errno = ENAMETOOLONG;
+    ipc_warn("UDS: Cannot bind provided name to socket. Name too large");
+    return ipcSocketInternalError;
   }
-#ifndef USE_ABSTRACT_SOCKET
-  unlink(temp);
-#endif
-  TRACE(NCCL_INIT, "UDS: Creating socket %s", temp);
   strncpy(cliaddr.sun_path, temp, len);
 #ifdef USE_ABSTRACT_SOCKET
   cliaddr.sun_path[0] = '\0';  // Linux abstract socket trick
+#else
+  unlink(temp);
 #endif
   if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) {
-    WARN("UDS: Binding to socket %s failed : %s (%d)", strerror(errno), errno);
+    ipc_warn("UDS: Binding to socket %s failed", temp);
     close(fd);
-    return ncclSystemError;
+    return ipcSocketSystemError;
}
handle->fd = fd;
......@@ -79,24 +95,25 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash,
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
-  return ncclSuccess;
+  return ipcSocketSuccess;
 }
-ncclResult_t ncclIpcSocketGetFd(struct ncclIpcSocket *handle, int *fd) {
+ipcSocketResult_t ipcSocketGetFd(struct IpcSocketHandle *handle, int *fd) {
   if (handle == NULL) {
-    WARN("ncclSocketGetFd: pass NULL socket");
-    return ncclInvalidArgument;
+    errno = EINVAL;
+    ipc_warn("ipcSocketGetFd: pass NULL socket");
+    return ipcSocketInvalidArgument;
   }
   if (fd) *fd = handle->fd;
-  return ncclSuccess;
+  return ipcSocketSuccess;
 }
-ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) {
+ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle) {
   if (handle == NULL) {
-    return ncclInternalError;
+    return ipcSocketInternalError;
   }
   if (handle->fd <= 0) {
-    return ncclSuccess;
+    return ipcSocketSuccess;
   }
#ifndef USE_ABSTRACT_SOCKET
if (handle->socketName[0] != '\0') {
......@@ -105,10 +122,10 @@ ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) {
#endif
close(handle->fd);
-  return ncclSuccess;
+  return ipcSocketSuccess;
}
-ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen, int *recvFd) {
+ipcSocketResult_t ipcSocketRecvMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, int *recvFd) {
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1];
......@@ -138,39 +155,44 @@ ncclResult_t ncclIpcSocketRecvMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
-      WARN("UDS: Receiving data over socket failed : %d", errno);
-      return ncclSystemError;
+      ipc_warn("UDS: Receiving data over socket failed");
+      return ipcSocketSystemError;
     }
-    if (handle->abortFlag && *handle->abortFlag) return ncclInternalError;
+    if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
}
if (recvFd != NULL) {
if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
WARN("UDS: Receiving data over socket failed");
return ncclSystemError;
errno = EBADMSG;
ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
return ipcSocketSystemError;
}
memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
} else {
errno = ENOMSG;
ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
return ipcSocketSystemError;
}
} else {
errno = EINVAL;
ipc_warn("UDS: File descriptor pointer cannot be NULL");
return ipcSocketInvalidArgument;
}
return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *recvFd) {
return ipcSocketRecvMsg(handle, NULL, 0, recvFd);
}
ipcSocketResult_t ipcSocketSendMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, const int sendFd,
int rank, uint64_t hash) {
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1];
char temp[IPC_SOCKNAME_LEN];
union {
struct cmsghdr cm;
......@@ -185,10 +207,11 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
bzero(&cliaddr, sizeof(cliaddr));
cliaddr.sun_family = AF_UNIX;
size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) {
errno = ENAMETOOLONG;
ipc_warn("UDS: Cannot connect to provided name for socket. Name too large");
return ipcSocketInternalError;
}
(void)strncpy(cliaddr.sun_path, temp, len);
......@@ -196,11 +219,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#endif
if (sendFd != -1) {
msg.msg_control = control_un.control;
msg.msg_controllen = sizeof(control_un.control);
......@@ -228,15 +247,16 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
ssize_t sendResult;
while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
ipc_warn("UDS: Sending data over socket %s failed", temp);
return ipcSocketSystemError;
}
if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
}
return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int sendFd, int rank,
uint64_t hash) {
return ipcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
}
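The renamed helpers above still wrap the standard SCM_RIGHTS mechanism for passing a file descriptor across a UNIX-domain socket. A minimal sketch of the same handshake using Python's stdlib wrappers (Python 3.9+, Linux) rather than the C API; the socket pair below merely stands in for the named UDS that `ipcSocketInit()` would bind between two processes:

```python
import os
import socket

# A connected UNIX-domain socket pair, in place of a named UDS.
a, b = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)

# Something worth sharing between processes: the read end of a pipe.
r, w = os.pipe()

# SCM_RIGHTS ancillary data, which ipcSocketSendFd()/ipcSocketRecvFd()
# assemble by hand with CMSG_* macros in the C code above.
socket.send_fds(a, [b"hdr"], [r])
_, fds, _, _ = socket.recv_fds(b, 16, 1)

# The received descriptor is a fresh fd aliasing the same pipe.
os.write(w, b"ok")
msg = os.read(fds[0], 2)
print(msg.decode())  # ok
```

In the real implementation the two endpoints live in different processes on the same node, which is why the kernel-level fd duplication that SCM_RIGHTS performs is needed at all.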
......@@ -4,10 +4,9 @@
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
#define TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
......@@ -21,32 +20,33 @@
#include <unistd.h>
typedef enum {
ipcSocketSuccess = 0,
ipcSocketUnhandledCudaError = 1,
ipcSocketSystemError = 2,
ipcSocketInternalError = 3,
ipcSocketInvalidArgument = 4,
ipcSocketInvalidUsage = 5,
ipcSocketRemoteError = 6,
ipcSocketInProgress = 7,
ipcSocketNumResults = 8
} ipcSocketResult_t;
const char *ipcSocketGetErrorString(ipcSocketResult_t res);
#define IPC_SOCKNAME_LEN 64
struct IpcSocketHandle {
int fd;
char socketName[IPC_SOCKNAME_LEN];
volatile uint32_t *abortFlag;
};
ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
volatile uint32_t *abortFlag);
ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle);
ipcSocketResult_t ipcSocketGetFd(IpcSocketHandle *handle, int *fd);
ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *fd);
ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int fd, int rank, uint64_t hash);
#endif /* TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H */
......@@ -23,15 +23,6 @@
#define MAX_THREADS 1024
#define ATOMIC_CONSUMER(chunk) \
if (counters) { \
if (threadIdx.x == 0 && blockIdx.x == 0) { \
......@@ -1391,7 +1382,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag<x> \
: userbuffers_fp16_sum_inplace_gpu_rw_ag<x>), \
......@@ -1416,7 +1407,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_ag<x>), kernelArgs)); \
}
......@@ -1436,7 +1427,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs<x>), kernelArgs)); \
}
......@@ -1458,7 +1449,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs<x>), kernelArgs)); \
}
......@@ -1481,7 +1472,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop<x>), \
kernelArgs)); \
}
......@@ -1506,7 +1497,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8<x, fp8type>), \
kernelArgs)); \
......@@ -1532,7 +1523,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop<x>), \
kernelArgs)); \
}
......@@ -1562,7 +1553,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16), \
reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8<x, fp8type>), \
......@@ -1588,7 +1579,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride<x>), \
kernelArgs)); \
}
......@@ -1614,7 +1605,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15)}; \
NVTE_CHECK_CUDA(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic<x>), \
kernelArgs)); \
......@@ -1641,7 +1632,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12), \
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14), \
reinterpret_cast<void *>(&arg15)}; \
NVTE_CHECK_CUDA( \
cudaLaunchKernelExC(&cfg, \
reinterpret_cast<void *>( \
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic<x>), \
......@@ -2206,15 +2197,6 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
}
}
// Return TRUE if two ranks share the same NV domain
#define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize))
......@@ -2259,7 +2241,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
if (comm->use_ce) {
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
......@@ -2269,7 +2251,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
reinterpret_cast<void *>(&arg5)};
NVTE_CHECK_CUDA(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
}
}
......@@ -2291,7 +2273,8 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
if (comm->use_ce) {
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
NVTE_CHECK_CUDA(
cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
......@@ -2323,7 +2306,7 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12),
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
reinterpret_cast<void *>(&arg15)};
NVTE_CHECK_CUDA(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv), kernelArgs));
}
......@@ -2346,7 +2329,8 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
reinterpret_cast<char *>(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset;
if (comm->use_ce) {
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_start_ptr));
NVTE_CHECK_CUDA(
cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
// kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast<int *>(ce_send_end_ptr));
}
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
......@@ -2379,8 +2363,8 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12),
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16)};
NVTE_CHECK_CUDA(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic), kernelArgs));
}
void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
......@@ -2425,7 +2409,7 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
reinterpret_cast<void *>(&arg13), reinterpret_cast<void *>(&arg14),
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16),
reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)};
NVTE_CHECK_CUDA(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_multiatomic), kernelArgs));
}
......@@ -2451,7 +2435,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
if (!signalonly)
kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
if (comm->use_ce) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
}
} else {
kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(
......@@ -15,39 +15,11 @@
#include <functional>
#include <stdexcept>
#include "common/util/logging.h"
#ifdef NVTE_UB_WITH_MPI
#include <mpi.h>
typedef MPI_Comm ExtComm;
#else
typedef char *ExtComm;
#endif
......@@ -170,14 +142,13 @@ struct communicator {
volatile int tail;
// Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
std::function<void(void *, size_t, void *, size_t, ExtComm)> _allgather;
std::function<void(ExtComm)> _barrier;
std::function<void(void *)> _free;
ExtComm comm_world,
comm_inter, // reduction group communicator (subset of the nodes) along GPU rail
comm_intra; // full intranode (all ndev GPUS)
#ifdef NVTE_UB_WITH_MPI
MPI_Request mpihndl[NVTE_MAX_SHARP];
#endif
......@@ -194,20 +165,19 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr
/* creates communicator, allocates all internal buffers if necessary */
int create_communicator_grouped2(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes);
int create_communicator_grouped(
communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier, int pipegpus, int pipenodes);
int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes,
std::function<void(void *, size_t, void *, size_t, ExtComm)> ext_allgather,
std::function<void(ExtComm)> ext_barrier);
int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
int tensorgpus, int tensornodes);
......@@ -7,6 +7,9 @@ import io
import os
import pickle
import warnings
import socket
import fcntl
import struct
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager
......@@ -79,19 +82,109 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
def initialize_ub(
shape: list,
tp_size: int,
use_fp8: bool = False,
dtype: torch.dtype = torch.bfloat16,
ub_cfgs: Optional[dict] = None,
bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
"""Initialize communicators for TP comm overlap using userbuffers."""
if not tex.device_supports_multicast():
assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
"CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
+ "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
)
global _ub_communicators
assert _ub_communicators is None, "UB communicators are already initialized."
_ub_communicators = {}
if tex.ubuf_built_with_mpi():
# Userbuffers will ignore all these values when it is built with MPI, so these are just
# placeholders based on an assumption that tp_size covers all devices in a physical node.
assert torch.distributed.is_mpi_available()
mpi_group = torch.distributed.new_group(backend="mpi")
world_rank = torch.distributed.get_rank(mpi_group)
world_size = torch.distributed.get_world_size(mpi_group)
local_rank = world_rank % tp_size
local_size = tp_size
node_id = world_rank // tp_size
num_nodes = world_size // tp_size
ub_callbacks = tex.UbufBootstrapCallbacks()
else:
assert (
torch.distributed.is_initialized()
), "torch.distributed must be initialized before Userbuffers"
if bootstrap_backend is None:
bootstrap_backend = "nccl"
if torch.distributed.is_gloo_available():
bootstrap_backend = "gloo"
elif torch.distributed.is_mpi_available():
bootstrap_backend = "mpi"
else:
assert bootstrap_backend in ["gloo", "mpi", "nccl"]
world_group = torch.distributed.new_group(backend=bootstrap_backend)
world_rank = torch.distributed.get_rank(world_group)
world_size = torch.distributed.get_world_size(world_group)
if world_rank == 0:
print(
f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n',
end="",
flush=True,
)
# Construct an intra-node communicator based on global ranks that share the same hostname
# NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host
# address on that interface instead of the hostname. This can help avoid issues when
# different hosts have the same hostname on Kubernetes clusters.
hostname = socket.gethostname()
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
)
if ifname is not None:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err
hostnames = [None for _ in range(world_size)]
torch.distributed.all_gather_object(hostnames, hostname, world_group)
intra_node_ranks = []
for i, host in enumerate(hostnames):
if host == hostname:
intra_node_ranks.append(i)
if len(intra_node_ranks) == world_size:
intra_node_group = world_group
local_rank = world_rank
local_size = world_size
intra_node_ranks = list(range(world_size))
else:
intra_node_group = torch.distributed.new_group(
backend=bootstrap_backend, ranks=intra_node_ranks
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
node_id = world_rank // local_size
num_nodes = world_size // local_size
if local_rank == 0:
print(
f"!!! [NVTE] Number of physical nodes: {num_nodes}\n"
+ f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n",
end="",
flush=True,
)
ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group)
# Increase the workspace by the number of maximum concurrent streams
global _cublas_workspace
......@@ -127,6 +220,23 @@ def initialize_ub(
return method
raise KeyError(f"Given layer name {name} does not exist.")
def get_default_config(name):
method = get_method(name)
is_reduce_scatter = name in layers_reduce_scatter_overlap
default_cfg = {
"method": method,
"is_reduce_scatter": is_reduce_scatter,
"num_sm": 1 if method == "ring_exchange" else 16,
"cga_size": 1 if method == "ring_exchange" else 2,
"set_sm_margin": False,
"num_splits": 4 if method == "pipeline" else tp_size,
"aggregate": False,
"atomic_gemm": False,
"use_ce": True,
"fp8_buf": name in layers_all_gather_overlap,
}
return default_cfg
def add_ub(
name: str,
method: str,
......@@ -180,53 +290,43 @@ def initialize_ub(
if method == "ring_exchange":
ub_obj = tex.UbufP2PCommOverlap(
sample_buffer, # Sample userbuffer
world_rank, # World rank
world_size, # World size
local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node
node_id, # Node ID
num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
set_sm_margin, # Set SM margin
aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
is_reduce_scatter, # Overlap with reduce scatter
atomic_gemm, # Use a single GEMM with atomic-counters
use_ce, # Use copy engine for P2P communications
ub_callbacks,
)
else:
ub_obj = tex.UbufCommOverlap(
sample_buffer, # Sample userbuffer
world_rank, # World rank
world_size, # World size
local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node
node_id, # Node ID
num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
num_splits, # Number of communication splits
set_sm_margin, # Set SM margin
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
atomic_gemm, # Use a single GEMM with atomic-counters
ub_callbacks,
)
_ub_communicators[name] = ub_obj
if ub_cfgs is not None:
for name in dgrad_reduce_scatter_overlap:
if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
......@@ -238,48 +338,18 @@ def initialize_ub(
methods["pipeline"].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
ub_cfg = get_default_config(name)
if ub_cfgs is not None and name in ub_cfgs:
# Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter
fp8_buf = (name in layers_all_gather_overlap) or (
ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
)
ub_cfg.update(ub_cfgs[name])
ub_cfg["fp8_buf"] = fp8_buf
add_ub(name, **ub_cfg)
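The loop above keeps user overrides minimal: `get_default_config()` supplies a complete per-layer config and `dict.update()` layers the user's `ub_cfgs` entry on top, so only the overridden keys change. A toy illustration of that merge (the key names mirror the config dict above, but the values here are illustrative, not the real defaults table):

```python
# Defaults analogous to what get_default_config() returns for a
# "ring_exchange" layer (illustrative values only).
default_cfg = {"method": "ring_exchange", "num_sm": 1, "cga_size": 1, "use_ce": True}

# A user override that only touches one knob.
user_cfg = {"num_sm": 4}

cfg = dict(default_cfg)  # start from a copy of the defaults
cfg.update(user_cfg)     # overlay only the user-specified keys

print(cfg["num_sm"], cfg["use_ce"])  # 4 True
```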
def get_ub(name: str):
"""Get the userbuffer communicator corresponding to the given key."""
global _ub_communicators
assert _ub_communicators is not None, "UB manager is not initialized."
assert name in _ub_communicators, f"UB for {name} is not registered."
return _ub_communicators[name]
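The hostname-based grouping that `initialize_ub()` performs can be sketched without `torch.distributed`: ranks reporting the same hostname form one intra-node group, and the node ID falls out of the world rank and per-node group size, assuming ranks are laid out contiguously per node as in the code above. The hostnames below are hypothetical:

```python
# All-gathered hostnames, indexed by world rank (two ranks per node here),
# standing in for the all_gather_object() result in initialize_ub().
hostnames = ["node0", "node0", "node1", "node1"]
world_size = len(hostnames)

def node_info(world_rank):
    # Ranks sharing my hostname form the intra-node group.
    me = hostnames[world_rank]
    intra_node_ranks = [i for i, h in enumerate(hostnames) if h == me]
    local_rank = intra_node_ranks.index(world_rank)
    local_size = len(intra_node_ranks)
    # Same derivation as initialize_ub(): contiguous rank blocks per node.
    node_id = world_rank // local_size
    num_nodes = world_size // local_size
    return intra_node_ranks, local_rank, local_size, node_id, num_nodes

print(node_info(2))  # ([2, 3], 0, 2, 1, 2)
```

Note that the `world_rank // local_size` derivation relies on the contiguous layout; with ranks interleaved across nodes the hostname grouping still works but the node ID would need a different mapping.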