Unverified Commit fa4b866d authored by Alp Dener, committed by GitHub

[C/PyTorch] Fixed incorrect use of `torch.distributed.new_group()` when creating intra-node group in `initialize_ub()` (#1087)

* Updated `initialize_ub()` to use `new_subgroups_by_enumeration()` to generate intra-node groups; added new unit tests for TE layers with comm+GEMM overlap
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 86f27e12
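The crux of the fix: `torch.distributed.new_group()` must be entered by every rank in the default group with the same `ranks` argument, so building one group per node by having each rank pass only its own node's ranks (as `initialize_ub()` previously did) is unsupported and breaks on multi-node runs. `torch.distributed.new_subgroups_by_enumeration()` instead takes the full partition of world ranks, creates every subgroup consistently on all ranks, and returns the subgroup containing the caller. Below is a minimal sketch of the corrected pattern; it is not the TE code, and the helper name is illustrative.

```python
# Minimal sketch (hypothetical helper, not TE's implementation): build an
# intra-node process group from gathered hostnames.
import socket
import torch.distributed as dist


def make_intra_node_group(world_rank, world_size, backend="gloo"):
    # Gather every rank's hostname so all ranks see the same node layout.
    hostnames = [None] * world_size
    dist.all_gather_object(hostnames, socket.gethostname())

    # Partition world ranks by hostname, e.g. [[0, 1, 2, 3], [4, 5, 6, 7]].
    ranks_per_node = [
        [i for i, h in enumerate(hostnames) if h == host]
        for host in dict.fromkeys(hostnames)  # unique hosts, order preserved
    ]

    # Incorrect (what the old code effectively did): each rank passes only its
    # own node's ranks, so ranks on different nodes enter new_group() with
    # different arguments in the same collective call.
    #   my_node = next(r for r in ranks_per_node if world_rank in r)
    #   return dist.new_group(ranks=my_node, backend=backend)

    # Correct: every rank passes the full partition; all subgroups are created
    # consistently and the one containing this rank is returned.
    intra_node_group, _ = dist.new_subgroups_by_enumeration(
        ranks_per_node, backend=backend
    )
    return intra_node_group
```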
......@@ -7,6 +7,8 @@
import os
import sys
import socket
import fcntl
import struct
import argparse
import warnings
......@@ -15,15 +17,37 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"
def _te_layer_argtype(name):
te_layers = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers))
if name.lower() not in layer_map.keys():
raise argparse.ArgumentTypeError(
f"Invalid TE layer name! Please choose from: {layer_map.keys()}"
)
return layer_map[name.lower()]
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers."
)
parser.add_argument(
"-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
......@@ -37,10 +61,10 @@ def _parse_args(argv=None, namespace=None):
"-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
)
parser.add_argument(
"--mlp-expansion-factor",
type=int,
default=4,
help="MLP block intermediate size as a factor of hidden dimension.",
"--layer-type",
type=_te_layer_argtype,
default=te.TransformerLayer,
help="Transformer Engine layer to train with comm+GEMM overlap.",
)
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
......@@ -88,9 +112,57 @@ def _parse_args(argv=None, namespace=None):
help="Print out additional debug information.",
)
args = parser.parse_args(argv, namespace)
if args.bootstrap_backend == "nccl":
args.bind_to_device = True
return args
def _get_layer_args(config, tp_group, tp_size, reference=False):
hidden_size = config.num_heads * config.head_dim
input_shape = [config.seq_length, config.batch_size, hidden_size]
args = [hidden_size]
kwargs = {
"params_dtype": torch.float32,
"device": "cuda",
"tp_group": tp_group,
"tp_size": tp_size,
"sequence_parallel": True,
}
kwargs["ub_overlap_ag"] = not config.no_comm_overlap
if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = "row"
kwargs["ub_overlap_rs"] = not config.no_comm_overlap
kwargs["ub_name"] = "proj"
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not config.no_comm_overlap
kwargs["ub_bulk_dgrad"] = not config.no_comm_overlap
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
kwargs["ub_name"] = "qkv"
else:
kwargs["set_parallel_mode"] = True
kwargs["ub_overlap_rs"] = not config.no_comm_overlap
if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
args.append(4 * hidden_size)
kwargs["seq_length"] = config.seq_length
if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
args.append(config.num_heads)
kwargs["attention_dropout"] = 0.0
kwargs["fuse_qkv_params"] = True
if config.layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
else:
kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap
kwargs["hidden_dropout"] = 0.0
return args, kwargs, input_shape
def _train(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N`
......@@ -110,19 +182,6 @@ def _train(opts):
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
NUM_NODES = WORLD_SIZE // LOCAL_SIZE
def dist_print(msg, group=None, end="\n", debug=False):
if debug and not opts.debug:
return
group = dist.new_group() if group is None else group
group_rank = dist.get_rank(group)
group_size = dist.get_world_size(group)
all_ranks = dist.get_process_group_ranks(group)
ranks_skip = all_ranks[1] - all_ranks[0] > 1
group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size
if group_rank == 0 or opts.verbose:
print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True)
dist.barrier(group)
# Initialize torch.distributed global process group and get DP/TP groups
torch.cuda.set_device(LOCAL_RANK)
dist_init_kwargs = {
......@@ -143,75 +202,117 @@ def _train(opts):
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
dist_print(f"Initialized default NCCL process group with {WORLD_RANK} GPUs", nccl_world)
def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False):
if debug and not opts.debug:
return
group_rank = dist.get_rank(group)
stream = sys.stderr if error else sys.stdout
if group_rank == src:
stream.write(f"[rank{WORLD_RANK}] {msg}{end}")
dist.barrier(group)
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Figure out process groups for tensor- and data-parallelism (if any)
if NUM_NODES > 1:
# Create a list of world ranks on this node
hostnames = [None for _ in range(WORLD_SIZE)]
hostname = socket.gethostname()
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
)
if ifname is not None:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err
hostnames = [None for _ in range(WORLD_SIZE)]
dist.all_gather_object(hostnames, hostname)
node_ranks = []
unique_hosts = []
for host in hostnames:
if host not in unique_hosts:
unique_hosts.append(host)
assert len(unique_hosts) == NUM_NODES
ranks_per_node_list = [[] for _ in range(NUM_NODES)]
self_node_idx = -1
for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname:
node_ranks.append(i)
self_node_idx = node_idx
assert self_node_idx >= 0
self_node_ranks = ranks_per_node_list[self_node_idx]
if opts.num_replicas > 1:
# Split node ranks into multiple replicas
assert len(node_ranks) % opts.num_replicas == 0
tp_size = len(node_ranks) // opts.num_replicas
found_replica = False
for replica in range(opts.num_replicas):
start = replica * tp_size
end = start + tp_size
tp_ranks = node_ranks[start:end]
if WORLD_RANK in tp_ranks:
found_replica = True
assert len(self_node_ranks) % opts.num_replicas == 0
tp_size = len(self_node_ranks) // opts.num_replicas
ranks_per_replica_list = []
for node_ranks in ranks_per_node_list:
for i in range(opts.num_replicas):
start = i * tp_size
end = start + tp_size
ranks_per_replica_list.append(node_ranks[start:end])
self_replica_idx = -1
for i, replica_ranks in enumerate(ranks_per_replica_list):
if WORLD_RANK in replica_ranks:
self_replica_idx = i
break
assert found_replica
assert self_replica_idx >= 0
else:
# The entire node is the tensor-parallel group
tp_ranks = node_ranks
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
tp_size = dist.get_world_size(tp_group)
tp_rank = dist.get_rank(tp_group)
ranks_per_replica_list = ranks_per_node_list
self_replica_idx = self_node_idx
# Data-parallelism across TP groups
dp_start = tp_rank
dp_end = dp_start + WORLD_SIZE
dp_ranks = list(range(dp_start, dp_end, tp_size))
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl")
ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32)
dp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
)
else:
if opts.num_replicas > 1:
# Mixed data- and tensor-parallelism on a single node
# NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas))
node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist()
tp_ranks = mesh2d[node_idx[0], :].tolist()
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
dp_ranks = mesh2d[:, node_idx[1]].tolist()
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
ranks_per_replica_tensor = all_ranks.reshape(
(opts.num_replicas, LOCAL_SIZE // opts.num_replicas)
)
tp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.tolist(), backend="nccl"
)
dp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
)
else:
dp_group = None
tp_group = nccl_world
tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
dist_print(
f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}",
group=tp_group,
)
if dp_group is not None:
dp_rank = dist.get_rank(dp_group)
dist_print(
f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}",
group=dp_group,
)
else:
dp_rank = 0
# Initialize Userbuffers
hidden_size = opts.num_heads * opts.head_dim
......@@ -226,26 +327,12 @@ def _train(opts):
)
# Initialize the fused LayerNorm + Multi-layer Perceptron module
torch.manual_seed(opts.seed + tp_rank)
torch.manual_seed(opts.seed + dp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank)
model = te.LayerNormMLP(
hidden_size,
opts.mlp_expansion_factor * hidden_size,
params_dtype=torch.bfloat16,
device="cuda",
tp_group=tp_group,
tp_size=tp_size,
set_parallel_mode=True,
sequence_parallel=True, # this is required for comm+GEMM overlap
seq_length=opts.seq_length,
ub_overlap_rs=not opts.no_comm_overlap,
ub_overlap_ag=not opts.no_comm_overlap,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_bulk_dgrad=False,
ub_bulk_wgrad=not opts.no_comm_overlap,
)
layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size)
model = opts.layer_type(*layer_args, **layer_kwargs)
if dp_group is not None:
model = DistributedDataParallel(model, process_group=dp_group)
model = DistributedDataParallel(model, dim=1, process_group=dp_group)
# Initialize optimizer with model parameters
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
......@@ -255,28 +342,28 @@ def _train(opts):
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations
dist_print("Starting training iterations...", nccl_world)
dist_print("Starting training iterations...")
for i in range(opts.num_iters):
dist_print(f" Iter {i+1}", tp_group, debug=True)
dist_print(" |-- Generate random input batch", tp_group, debug=True)
x = torch.rand(
(opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
dist_print(" |-- Forward pass", tp_group, debug=True)
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x)
dist_print(" |-- Compute loss", tp_group, debug=True)
loss = y.flatten().sum()
dist_print(" |-- Backward pass", tp_group, debug=True)
dist_print(f" Iter {i+1}", group=tp_group, debug=True)
dist_print(" |-- Generate random input batch", group=tp_group, debug=True)
x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True)
dist_print(" |-- Forward pass", group=tp_group, debug=True)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x)
if isinstance(y, tuple):
out, *_ = y
else:
out = y
dist_print(" |-- Compute loss", group=tp_group, debug=True)
loss = out.sum()
dist_print(" |-- Backward pass", group=tp_group, debug=True)
loss.backward()
dist_print(" |-- Optimizer step", tp_group, debug=True)
dist_print(" |-- Optimizer step", group=tp_group, debug=True)
optim.step()
torch.cuda.synchronize()
......
......@@ -5,6 +5,7 @@
set -e
: ${TE_PATH:=/opt/transformerengine}
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
......
#!/usr/bin/python3
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import socket
import argparse
import warnings
from functools import partial
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
def _te_layer_argtype(name):
te_layers = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers))
if name.lower() not in layer_map.keys():
raise argparse.ArgumentTypeError(
f"Invalid TE layer name! Please choose from: {layer_map.keys()}"
)
return layer_map[name.lower()]
def _get_layer_args(config, tp_group, tp_size, reference=False):
hidden_size = config.num_heads * config.head_dim
input_shape = [config.seq_length, config.batch_size, hidden_size]
args = [hidden_size]
kwargs = {
"params_dtype": torch.float32,
"device": "cuda",
"tp_group": tp_group,
"tp_size": tp_size,
"sequence_parallel": True,
}
kwargs["ub_overlap_ag"] = not reference
if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = "row"
kwargs["ub_overlap_rs"] = not reference
kwargs["ub_name"] = "proj"
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not reference
kwargs["ub_bulk_dgrad"] = not reference
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
kwargs["ub_name"] = "qkv"
else:
kwargs["set_parallel_mode"] = True
kwargs["ub_overlap_rs"] = not reference
if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
args.append(4 * hidden_size)
kwargs["seq_length"] = config.seq_length
if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
args.append(config.num_heads)
kwargs["attention_dropout"] = 0.0
kwargs["fuse_qkv_params"] = True
if config.layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
else:
kwargs["ub_tp_comm_overlap"] = not reference
kwargs["hidden_dropout"] = 0.0
return args, kwargs, input_shape
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers."
)
parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP)
parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
parser.add_argument(
"-n", "--num-heads", type=int, default=12, help="Number of attention heads."
)
parser.add_argument(
"-d", "--head-dim", type=int, default=64, help="Dimension of each attention head."
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
)
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
parser.add_argument(
"--tcp-init",
action="store_true",
default=False,
help="Initialize torch.distributed with TcpStore.",
)
parser.add_argument(
"--bind-to-device",
action="store_true",
default=False,
help="Initialize torch.distributed with `device_id` to bind each rank to a single device.",
)
parser.add_argument(
"--bootstrap-backend",
type=str.lower,
default="nccl",
choices=["gloo", "mpi", "nccl"],
help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
)
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
)
parser.add_argument(
"--debug",
action="store_true",
default=False,
help="Print out additional debug information.",
)
args = parser.parse_args(argv, namespace)
if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
args.use_cuda_graphs = False
return args
def _compare_tensors(name, test, ref, rtol, atol):
# Make sure tensors aren't zero and we don't pass trivially
if test.count_nonzero() == 0:
if ref.count_nonzero() == 0:
warnings.warn(
f"WARNING: {name} is a zero-tensor for both test and reference models!",
category=RuntimeWarning,
)
else:
numerics_info = (
f"NUMERICAL CHECK FAILED: {name} is a zero-tensor but does not match reference!"
)
return 1, numerics_info
diff = torch.abs(test - ref).flatten()
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5)
numerics_failed = 0
if rel_err > rtol and abs_err > atol:
numerics_failed = 1
numerics_info = (
"NUMERICAL CHECK FAILED: "
+ f"{name} not close enough at index {m.item()} "
+ f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | "
+ f"rel. error = {rel_err} (tol = {rtol}) | "
+ f"abs. error = {abs_err} (tol = {atol})"
)
else:
numerics_info = f"NUMERICAL CHECK PASSED: {name} | "
if rel_err <= rtol:
numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
" | " if abs_err <= atol else "."
)
if abs_err <= atol:
numerics_info += f" abs. error = {abs_err} (tol = {atol})"
return numerics_failed, numerics_info
def _train(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N`
WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True
opts.bind_to_device = True
opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ:
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
assert LOCAL_SIZE == WORLD_SIZE
def dist_print(msg, src=None, end="\n", debug=False, error=False):
if debug and not opts.debug:
return
stream = sys.stderr if error else sys.stdout
if WORLD_RANK == (0 if src is None else src):
stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n")
dist.barrier()
# Set device and initialize RNG states
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)
# Initialize torch.distributed global process group and get DP/TP groups
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
if opts.tcp_init:
MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
MASTER_PORT = os.getenv("MASTER_PORT", "1234")
dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
if opts.bind_to_device or opts.bootstrap_backend == "nccl":
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Initialize Userbuffers
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
WORLD_SIZE,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
)
# Initialize the Transformer Engine layer with overlap
args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE)
with te.fp8_model_init(enabled=opts.fp8_init):
test_model = opts.layer_type(*args, **kwargs)
dist_print("Initialized test model...", debug=True)
# Initialize the reference model and copy all parameters
ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True)
with te.fp8_model_init(enabled=opts.fp8_init):
ref_model = opts.layer_type(*ref_args, **ref_kwargs)
dist_print("Initialized reference model...", debug=True)
for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
with torch.no_grad():
ref_param.copy_(test_param)
torch.testing.assert_close(test_param, ref_param, rtol=0.0, atol=0.0)
dist_print("Copied parameters from test model to reference model...", debug=True)
# Fp8 recipe setup
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
test_x.retain_grad()
ref_x = torch.empty_like(test_x).requires_grad_(True)
with torch.no_grad():
ref_x.copy_(test_x)
torch.testing.assert_close(test_x, ref_x, rtol=0.0, atol=0.0)
ref_x.retain_grad()
# Execute fwd/bwd and collect tensors to test
def run_fwd_bwd(model, x):
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x)
if isinstance(y, tuple):
out, *_ = y
else:
out = y
loss = out.sum()
loss.backward()
return out
torch_rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}"))
if opts.use_cuda_graphs:
test_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(test_graph):
test_out = run_fwd_bwd(test_model, test_x)
test_graph.replay()
del test_graph
else:
test_out = run_fwd_bwd(test_model, test_x)
test_grads = [test_out, test_x.grad]
names = ["output", "input.grad"]
for test_name, test_param in test_model.named_parameters():
if test_param.requires_grad and "layer_norm" not in test_name:
test_grads.append(test_param.grad)
names.append(test_name + ".grad")
torch.set_rng_state(torch_rng_state)
torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}"))
if opts.use_cuda_graphs:
ref_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(ref_graph):
ref_out = run_fwd_bwd(ref_model, ref_x)
ref_graph.replay()
del ref_graph
else:
ref_out = run_fwd_bwd(ref_model, ref_x)
ref_grads = [ref_out, ref_x.grad]
for ref_name, ref_param in ref_model.named_parameters():
if ref_param.requires_grad and "layer_norm" not in ref_name:
ref_grads.append(ref_param.grad)
# Make sure we have the same number of gradients
numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
if len(test_grads) != len(ref_grads):
numerics_failed[0] = 1
numerics_info = (
"NUMERICAL CHECK FAILED: Incorrect number of gradients, "
+ f"expected {len(ref_grads)} but got {len(test_grads)}."
)
dist_print(numerics_info, src=WORLD_RANK, error=True)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
# Now validate accuracy
if not bool(numerics_failed.item()):
for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
rtol = 0.125 if opts.fp8 else 0.025
atol = 0.0625 if opts.fp8 else 0.00125
grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
numerics_failed[0] = int(grad_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
if bool(numerics_failed.item()):
break
te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True)
dist_print("Destroying all process groups...", debug=True)
dist.destroy_process_group()
if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True)
return numerics_failed[0].item()
if __name__ == "__main__":
sys.exit(_train(_parse_args()))
......@@ -7,16 +7,27 @@ from pathlib import Path
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
RNG_SEED: int = 1234
SEQ_LENGTH: int = 2024
SEQ_LENGTH: int = 512
BATCH_SIZE: int = 2
NUM_HEADS: int = 64
HEAD_DIM: int = 128
NUM_HEADS: int = 12
HEAD_DIM: int = 64
TE_LAYERS = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(torch.cuda.device_count(), 4)
......@@ -32,66 +43,28 @@ if not tex.device_supports_multicast():
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.")
@pytest.mark.parametrize(
"fp8,p2p,comm_type,aggregate,atomic,bulk",
[
# FP8, P2P, Type, Aggregate, Atomic, Bulk
(False, True, "AG", False, False, False),
(False, True, "AG", True, False, False),
(True, True, "AG", False, False, False),
(True, True, "AG", True, False, False),
(False, False, "RS", False, False, False),
(False, True, "RS", False, False, False),
(True, False, "RS", False, False, False),
(True, True, "RS", False, False, False),
(True, False, "RS", False, True, False),
(True, True, "RS", False, True, False),
(False, False, "AG", False, False, True),
(False, False, "RS", False, False, True),
],
ids=[
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ",
" SPLIT GEMM -> RS | BF16 | PIPELINE ",
" SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ",
" SPLIT GEMM -> RS | FP8 | PIPELINE ",
" SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ",
" ATOMIC GEMM -> RS | FP8 | PIPELINE ",
" ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ",
" BULK AG & GEMM | BF16 | PIPELINE ",
" BULK RS & GEMM | BF16 | PIPELINE ",
],
)
def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
"""
Test comm+GEMM overlap algorithms with direct calls to
te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm
"""
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate):
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = (
LAUNCH_CMD
+ [str(test_path)]
+ [
"--check-numerics",
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--comm-type={comm_type}",
]
)
test_cmd = LAUNCH_CMD + [
str(test_path),
"--check-numerics",
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--comm-type={comm_type}",
]
if bulk:
test_cmd.append("--bulk-overlap")
else:
if fp8:
if fp8_in:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if fp8_out:
test_cmd.append("--fp8-output")
if p2p:
test_cmd.append("--p2p")
if aggregate:
......@@ -101,5 +74,173 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.")
test_cmd.append("--atomic")
output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False)
assert "NUMERICAL CHECK PASSED" in str(output)
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(result.stderr.decode())
def _run_layer_with_overlap(layer_type, fp8, fp8_init):
test_path = TEST_ROOT / "run_layer_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--layer-type={layer_type}",
]
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if fp8_init:
test_cmd.append("--fp8-init")
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
os.unsetenv("PYTORCH_JIT")
os.unsetenv("NVTE_TORCH_COMPILE")
os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(result.stderr.decode())
@pytest.mark.parametrize(
"fp8,aggregate",
[
(False, False),
(False, True),
(True, False),
(True, True),
],
ids=[
" BF16 IN - RING-EXCHANGE ",
" BF16 IN - RING-EXCHANGE - 2x AGGREGATED ",
" FP8 IN - RING-EXCHANGE ",
" FP8 IN - RING-EXCHANGE - 2x AGGREGATED ",
],
)
def test_split_all_gather_overlaps(fp8, aggregate):
"""
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate)
@pytest.mark.parametrize(
"fp8_in,fp8_out,p2p",
[
(False, False, False),
(False, False, True),
(True, False, False),
(True, False, True),
(True, True, False),
(True, True, True),
],
ids=[
" BF16 IN - BF16 OUT - PIPELINE ",
" BF16 IN - BF16 OUT - RING-EXCHANGE ",
" FP8 IN - BF16 OUT - PIPELINE ",
" FP8 IN - BF16 OUT - RING-EXCHANGE ",
" FP8 IN - FP8 OUT - PIPELINE ",
" FP8 IN - FP8 OUT - RING-EXCHANGE ",
],
)
def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p):
"""
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False)
@pytest.mark.parametrize(
"ag_type,rs_type,p2p,fp8_out",
[
(0, 0, False, False),
(0, 1, False, False),
(0, 1, False, True),
(0, 2, False, False),
(0, 2, False, True),
(0, 0, True, False),
(0, 0, True, True),
(1, 0, True, False),
(1, 0, True, True),
],
ids=[
" NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ",
" NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ",
" NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ",
" NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ",
" NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ",
" NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ",
" NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ",
" MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ",
" MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ",
],
)
def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):
"""
Test paired (all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter) overlaps with
direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type)
os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type)
_run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False)
@pytest.mark.parametrize(
"comm_type,fp8",
[
("AG", False),
("RS", False),
("RS", True),
],
ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "],
)
def test_bulk_overlaps(comm_type, fp8):
"""
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
@pytest.mark.parametrize(
"layer_type",
[layer.__name__ for layer in TE_LAYERS],
ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS],
)
@pytest.mark.parametrize(
"fp8,fp8_init",
[
(False, False),
(True, False),
(True, True),
],
ids=[
" BF16 GEMM - BF16 PARAMS ",
" FP8 GEMM - BF16 PARAMS ",
" FP8 GEMM - FP8 PARAMS ",
],
)
def test_layers_with_overlap(layer_type, fp8, fp8_init):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, fp8, fp8_init)
......@@ -166,7 +166,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Initialize userbuf communicator
if (!comm_created) {
if (myrank == 0) {
printf("!!! [UB] Create UbufCommOverlap Communicator\n");
printf("!!! [UB] Create Userbuffers Communicator\n");
}
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
......@@ -184,16 +184,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Allocate and register extra userbuffers
int ubuf_bytes = sample.numel() * sample.element_size();
if (transformer_engine::getenv<bool>("UB_SKIPMC")) {
_ubuf = torch::zeros_like(sample);
_ubuf_ptr = _ubuf.data_ptr();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, false);
} else {
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
}
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
if (_ub_comm->myrank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg);
......@@ -264,6 +257,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type,
at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
......@@ -319,6 +313,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
_ub_comm->sms = ori_sms;
return {D, output_tensor};
} // bulk_overlap
......@@ -336,6 +331,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
......@@ -352,7 +348,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
......@@ -388,7 +383,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
D_type, fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m,
_num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm););
......@@ -402,7 +397,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
D_type, fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm););
......@@ -413,10 +408,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
break;
} else {
assert(_ubuf.element_size() != 1);
consumer(counter_ptr, i, (cudaStream_t)_stream_comm);
// if (i == _num_splits-1) {
// _ub_comm->sms = UB_MAX_SM;
// }
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
......@@ -447,6 +440,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output) {
// Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
......@@ -464,7 +458,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
......@@ -517,7 +510,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm););
......@@ -541,7 +534,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm););
......@@ -577,7 +570,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm););
......@@ -682,7 +675,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Initialize userbuf communicator
if (!comm_created) {
if (myrank == 0) {
printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n");
printf("!!! [UB] Create Userbuffers Communicator\n");
}
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
......@@ -708,19 +701,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1));
num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1);
}
if (transformer_engine::getenv<bool>("UB_SKIPMC")) {
_ubuf = torch::zeros({sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)},
sample.options());
_ubuf_ptr = _ubuf.data_ptr();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, false);
} else {
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf =
torch::from_blob(_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)},
sample.options());
}
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf = torch::from_blob(
_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options());
if (_ub_comm->myrank == 0) {
printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
}
......@@ -728,9 +713,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
for (int i = 0; i < num_ubuf_chunks; i++) {
torch::Tensor ubuf_chunk = torch::from_blob(
ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options());
_ubufs.push_back(ubuf_chunk);
auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)},
sample.options());
_ubufs.push_back(std::move(ubuf_chunk));
ubuf_byte_ptr += ubuf_chunk_bytes;
}
......@@ -769,6 +754,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (_rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') {
_use_ce = 0;
_ub_comm->push = 1;
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
}
}
......@@ -818,6 +805,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
......@@ -866,6 +854,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
if (i == 0) {
_ub_comm->use_ce = 0;
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
true, (cudaStream_t)_stream_recv);
......@@ -906,6 +895,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
// Return the last N rows of D_buffer
_ub_comm->sms = ori_sms;
torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n);
return D_return;
} // atomic_gemm_overlap_ag
......@@ -926,6 +916,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
......@@ -1078,6 +1069,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
at::cuda::setCurrentCUDAStream(stream_main);
_ub_comm->sms = ori_sms;
return D;
} // split_overlap_ag
......@@ -1094,6 +1086,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
......@@ -1149,7 +1142,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
D_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main););
} else {
......@@ -1157,6 +1150,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
_ub_comm->sms = ori_sms;
}
/*
......@@ -1171,6 +1165,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
......@@ -1245,7 +1240,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
D_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main););
} else {
......@@ -1259,6 +1254,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
_ub_comm->sms = ori_sms;
}
/*
......
......@@ -1861,6 +1861,14 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
}
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream) {
......
......@@ -107,7 +107,7 @@ def initialize_ub(
world_size = torch.distributed.get_world_size(mpi_group)
local_rank = world_rank % tp_size
local_size = tp_size
node_id = world_rank // tp_size
self_node_idx = world_rank // tp_size
num_nodes = world_size // tp_size
ub_callbacks = tex.UbufBootstrapCallbacks()
else:
......@@ -127,13 +127,6 @@ def initialize_ub(
world_rank = torch.distributed.get_rank(world_group)
world_size = torch.distributed.get_world_size(world_group)
if world_rank == 0:
print(
f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n',
end="",
flush=True,
)
# Construct an intra-node communicator based on global ranks that share the same hostname
# NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host
# address on that interface instead of the hostname. This can help avoid issues when
......@@ -157,28 +150,41 @@ def initialize_ub(
hostnames = [None for _ in range(world_size)]
torch.distributed.all_gather_object(hostnames, hostname, world_group)
intra_node_ranks = []
for i, host in enumerate(hostnames):
if host == hostname:
intra_node_ranks.append(i)
if len(intra_node_ranks) == world_size:
unique_hosts = []
for host in hostnames:
if host not in unique_hosts:
unique_hosts.append(host)
num_nodes = len(unique_hosts)
if num_nodes > 1:
ranks_per_node_list = [[] for _ in range(num_nodes)]
self_node_idx = -1
for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname:
self_node_idx = node_idx
assert self_node_idx >= 0, "Internal TE error!"
intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration(
ranks_per_node_list, backend=bootstrap_backend
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group)
else:
self_node_idx = 0
intra_node_group = world_group
local_rank = world_rank
local_size = world_size
intra_node_ranks = list(range(world_size))
else:
intra_node_group = torch.distributed.new_group(
backend=bootstrap_backend, ranks=intra_node_ranks
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
node_id = world_rank // local_size
num_nodes = world_size // local_size
if world_rank == 0:
print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True)
if local_rank == 0:
print(
f"!!! [NVTE] Number of physical nodes: {num_nodes}\n"
+ f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n",
f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n",
end="",
flush=True,
)
......@@ -293,7 +299,7 @@ def initialize_ub(
world_size, # World size
local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node
node_id, # Node ID
self_node_idx, # Node ID
num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs
......@@ -313,7 +319,7 @@ def initialize_ub(
world_size, # World size
local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node
node_id, # Node ID
self_node_idx, # Node ID
num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs
......
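For reference, a condensed sketch of the per-rank flow exercised by the new `run_layer_with_overlap.py` test (launched with `torchrun`): sizes match the test defaults and the flags are drawn from `_get_layer_args()` above, but this is illustrative and not a drop-in script.

```python
# Condensed sketch of the per-rank test flow added in this PR (illustrative):
# initialize Userbuffers, build a TE layer with comm+GEMM overlap, run fwd/bwd.
import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
dist.init_process_group(backend="nccl")  # env:// rendezvous via torchrun
world_size = dist.get_world_size()
nccl_world = dist.new_group(backend="nccl")

seq_length, batch_size, num_heads, head_dim = 512, 2, 12, 64
hidden_size = num_heads * head_dim

# Userbuffers must be initialized before constructing layers with ub_* flags.
te.module.base.initialize_ub(
    [seq_length * batch_size, hidden_size],
    world_size,
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",
)

# Sequence-parallel TransformerLayer with comm+GEMM overlap enabled.
model = te.TransformerLayer(
    hidden_size,
    4 * hidden_size,
    num_heads,
    params_dtype=torch.float32,
    device="cuda",
    tp_group=nccl_world,
    tp_size=world_size,
    set_parallel_mode=True,
    sequence_parallel=True,
    seq_length=seq_length,
    attention_dropout=0.0,
    hidden_dropout=0.0,
    fuse_qkv_params=True,
    ub_tp_comm_overlap=True,
)

x = torch.randn(
    (seq_length // world_size, batch_size, hidden_size),
    dtype=torch.float32,
    device="cuda",
    requires_grad=True,
)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    out = model(x)
out = out[0] if isinstance(out, tuple) else out
out.sum().backward()

te.module.base.destroy_ub()
dist.destroy_process_group()
```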