".github/vscode:/vscode.git/clone" did not exist on "1d075c0682a3b10b2ee7a381adbd8f0ab50f7664"
Unverified Commit fa4b866d authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[C/PyTorch] Fixed incorrect use of `torch.distributed.new_group()` when...


[C/PyTorch] Fixed incorrect use of `torch.distributed.new_group()` when creating intra-node group in `initialize_ub()` (#1087)

* updated initialize_ub() to use new_subgroups_by_enumeration() to generate intra-node groups, added new unit tests for TE layers with comm overlap
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 86f27e12
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
import os import os
import sys import sys
import socket import socket
import fcntl
import struct
import argparse import argparse
import warnings import warnings
...@@ -15,15 +17,37 @@ import torch.distributed as dist ...@@ -15,15 +17,37 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.common.recipe import Format, DelayedScaling from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"
def _te_layer_argtype(name):
te_layers = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers))
if name.lower() not in layer_map.keys():
raise argparse.ArgumentTypeError(
f"Invalid TE layer name! Please choose from: {layer_map.keys()}"
)
return layer_map[name.lower()]
def _parse_args(argv=None, namespace=None): def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers." description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers."
) )
parser.add_argument( parser.add_argument(
"-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations." "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
...@@ -37,10 +61,10 @@ def _parse_args(argv=None, namespace=None): ...@@ -37,10 +61,10 @@ def _parse_args(argv=None, namespace=None):
"-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
) )
parser.add_argument( parser.add_argument(
"--mlp-expansion-factor", "--layer-type",
type=int, type=_te_layer_argtype,
default=4, default=te.TransformerLayer,
help="MLP block intermediate size as a factor of hidden dimension.", help="Transformer Engine layer to train with comm+GEMM overlap.",
) )
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument( parser.add_argument(
...@@ -88,9 +112,57 @@ def _parse_args(argv=None, namespace=None): ...@@ -88,9 +112,57 @@ def _parse_args(argv=None, namespace=None):
help="Print out additional debug information.", help="Print out additional debug information.",
) )
args = parser.parse_args(argv, namespace) args = parser.parse_args(argv, namespace)
if args.bootstrap_backend == "nccl":
args.bind_to_device = True
return args return args
def _get_layer_args(config, tp_group, tp_size, reference=False):
hidden_size = config.num_heads * config.head_dim
input_shape = [config.seq_length, config.batch_size, hidden_size]
args = [hidden_size]
kwargs = {
"params_dtype": torch.float32,
"device": "cuda",
"tp_group": tp_group,
"tp_size": tp_size,
"sequence_parallel": True,
}
kwargs["ub_overlap_ag"] = not config.no_comm_overlap
if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = "row"
kwargs["ub_overlap_rs"] = not config.no_comm_overlap
kwargs["ub_name"] = "proj"
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not config.no_comm_overlap
kwargs["ub_bulk_dgrad"] = not config.no_comm_overlap
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
kwargs["ub_name"] = "qkv"
else:
kwargs["set_parallel_mode"] = True
kwargs["ub_overlap_rs"] = not config.no_comm_overlap
if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
args.append(4 * hidden_size)
kwargs["seq_length"] = config.seq_length
if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
args.append(config.num_heads)
kwargs["attention_dropout"] = 0.0
kwargs["fuse_qkv_params"] = True
if config.layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
else:
kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap
kwargs["hidden_dropout"] = 0.0
return args, kwargs, input_shape
def _train(opts): def _train(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ: if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N` # Execution with `mpirun -np N`
...@@ -110,19 +182,6 @@ def _train(opts): ...@@ -110,19 +182,6 @@ def _train(opts):
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
NUM_NODES = WORLD_SIZE // LOCAL_SIZE NUM_NODES = WORLD_SIZE // LOCAL_SIZE
def dist_print(msg, group=None, end="\n", debug=False):
if debug and not opts.debug:
return
group = dist.new_group() if group is None else group
group_rank = dist.get_rank(group)
group_size = dist.get_world_size(group)
all_ranks = dist.get_process_group_ranks(group)
ranks_skip = all_ranks[1] - all_ranks[0] > 1
group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size
if group_rank == 0 or opts.verbose:
print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True)
dist.barrier(group)
# Initialize torch.distributed global process group and get DP/TP groups # Initialize torch.distributed global process group and get DP/TP groups
torch.cuda.set_device(LOCAL_RANK) torch.cuda.set_device(LOCAL_RANK)
dist_init_kwargs = { dist_init_kwargs = {
...@@ -143,75 +202,117 @@ def _train(opts): ...@@ -143,75 +202,117 @@ def _train(opts):
assert dist.is_nccl_available() assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs) dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl") nccl_world = dist.new_group(backend="nccl")
dist_print(f"Initialized default NCCL process group with {WORLD_RANK} GPUs", nccl_world)
def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False):
if debug and not opts.debug:
return
group_rank = dist.get_rank(group)
stream = sys.stderr if error else sys.stdout
if group_rank == src:
stream.write(f"[rank{WORLD_RANK}] {msg}{end}")
dist.barrier(group)
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Figure out process groups for tensor- and data-parallelism (if any) # Figure out process groups for tensor- and data-parallelism (if any)
if NUM_NODES > 1: if NUM_NODES > 1:
# Create a list of world ranks on this node # Create a list of world ranks on this node
hostnames = [None for _ in range(WORLD_SIZE)]
hostname = socket.gethostname() hostname = socket.gethostname()
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
)
if ifname is not None:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err
hostnames = [None for _ in range(WORLD_SIZE)]
dist.all_gather_object(hostnames, hostname) dist.all_gather_object(hostnames, hostname)
node_ranks = [] unique_hosts = []
for host in hostnames:
if host not in unique_hosts:
unique_hosts.append(host)
assert len(unique_hosts) == NUM_NODES
ranks_per_node_list = [[] for _ in range(NUM_NODES)]
self_node_idx = -1
for i, host in enumerate(hostnames): for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname: if host == hostname:
node_ranks.append(i) self_node_idx = node_idx
assert self_node_idx >= 0
self_node_ranks = ranks_per_node_list[self_node_idx]
if opts.num_replicas > 1: if opts.num_replicas > 1:
# Split node ranks into multiple replicas # Split node ranks into multiple replicas
assert len(node_ranks) % opts.num_replicas == 0 assert len(self_node_ranks) % opts.num_replicas == 0
tp_size = len(node_ranks) // opts.num_replicas tp_size = len(self_node_ranks) // opts.num_replicas
found_replica = False ranks_per_replica_list = []
for replica in range(opts.num_replicas): for node_ranks in ranks_per_node_list:
start = replica * tp_size for i in range(opts.num_replicas):
end = start + tp_size start = i * tp_size
tp_ranks = node_ranks[start:end] end = start + tp_size
if WORLD_RANK in tp_ranks: ranks_per_replica_list.append(node_ranks[start:end])
found_replica = True
self_replica_idx = -1
for i, replica_ranks in enumerate(ranks_per_replica_list):
if WORLD_RANK in replica_ranks:
self_replica_idx = i
break break
assert found_replica assert self_replica_idx >= 0
else: else:
# The entire node is the tensor-parallel group # The entire node is the tensor-parallel group
tp_ranks = node_ranks ranks_per_replica_list = ranks_per_node_list
self_replica_idx = self_node_idx
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks)
tp_size = dist.get_world_size(tp_group)
tp_rank = dist.get_rank(tp_group)
# Data-parallelism across TP groups tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl")
dp_start = tp_rank ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32)
dp_end = dp_start + WORLD_SIZE dp_group, _ = dist.new_subgroups_by_enumeration(
dp_ranks = list(range(dp_start, dp_end, tp_size)) ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) )
else: else:
if opts.num_replicas > 1: if opts.num_replicas > 1:
# Mixed data- and tensor-parallelism on a single node # Mixed data- and tensor-parallelism on a single node
# NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas)) ranks_per_replica_tensor = all_ranks.reshape(
node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist() (opts.num_replicas, LOCAL_SIZE // opts.num_replicas)
)
tp_ranks = mesh2d[node_idx[0], :].tolist() tp_group, _ = dist.new_subgroups_by_enumeration(
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) ranks_per_replica_tensor.tolist(), backend="nccl"
)
dp_ranks = mesh2d[:, node_idx[1]].tolist() dp_group, _ = dist.new_subgroups_by_enumeration(
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
)
else: else:
dp_group = None dp_group = None
tp_group = nccl_world tp_group = nccl_world
tp_rank = dist.get_rank(tp_group) tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group) tp_size = dist.get_world_size(tp_group)
dist_print( dist_print(
f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}",
group=tp_group, group=tp_group,
) )
if dp_group is not None: if dp_group is not None:
dp_rank = dist.get_rank(dp_group)
dist_print( dist_print(
f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}",
group=dp_group, group=dp_group,
) )
else:
dp_rank = 0
# Intialize userbuffers # Intialize userbuffers
hidden_size = opts.num_heads * opts.head_dim hidden_size = opts.num_heads * opts.head_dim
...@@ -226,26 +327,12 @@ def _train(opts): ...@@ -226,26 +327,12 @@ def _train(opts):
) )
# Initialize the fused LayerNorm + Multi-layer Perceptron module # Initialize the fused LayerNorm + Multi-layer Perceptron module
torch.manual_seed(opts.seed + tp_rank) torch.manual_seed(opts.seed + dp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank) torch.cuda.manual_seed(opts.seed + tp_rank)
model = te.LayerNormMLP( layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size)
hidden_size, model = opts.layer_type(*layer_args, **layer_kwargs)
opts.mlp_expansion_factor * hidden_size,
params_dtype=torch.bfloat16,
device="cuda",
tp_group=tp_group,
tp_size=tp_size,
set_parallel_mode=True,
sequence_parallel=True, # this is required for comm+GEMM overlap
seq_length=opts.seq_length,
ub_overlap_rs=not opts.no_comm_overlap,
ub_overlap_ag=not opts.no_comm_overlap,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_bulk_dgrad=False,
ub_bulk_wgrad=not opts.no_comm_overlap,
)
if dp_group is not None: if dp_group is not None:
model = DistributedDataParallel(model, process_group=dp_group) model = DistributedDataParallel(model, dim=1, process_group=dp_group)
# Initialize optimizer with model parameters # Initialize optimizer with model parameters
optim = torch.optim.Adam(model.parameters(), lr=0.0001) optim = torch.optim.Adam(model.parameters(), lr=0.0001)
...@@ -255,28 +342,28 @@ def _train(opts): ...@@ -255,28 +342,28 @@ def _train(opts):
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations # Start dummy "training" iterations
dist_print("Starting training iterations...", nccl_world) dist_print("Starting training iterations...")
for i in range(opts.num_iters): for i in range(opts.num_iters):
dist_print(f" Iter {i+1}", tp_group, debug=True) dist_print(f" Iter {i+1}", group=tp_group, debug=True)
dist_print(" |-- Generate random input batch", tp_group, debug=True) dist_print(" |-- Generate random input batch", group=tp_group, debug=True)
x = torch.rand( x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True)
(opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16, dist_print(" |-- Forward pass", group=tp_group, debug=True)
device="cuda", with torch.amp.autocast("cuda", dtype=torch.bfloat16):
requires_grad=True, with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
) y = model(x)
if isinstance(y, tuple):
dist_print(" |-- Forward pass", tp_group, debug=True) out, *_ = y
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): else:
y = model(x) out = y
dist_print(" |-- Compute loss", tp_group, debug=True) dist_print(" |-- Compute loss", group=tp_group, debug=True)
loss = y.flatten().sum() loss = out.sum()
dist_print(" |-- Backward pass", tp_group, debug=True) dist_print(" |-- Backward pass", group=tp_group, debug=True)
loss.backward() loss.backward()
dist_print(" |-- Optimizer step", tp_group, debug=True) dist_print(" |-- Optimizer step", group=tp_group, debug=True)
optim.step() optim.step()
torch.cuda.synchronize() torch.cuda.synchronize()
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
set -e set -e
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
git clone https://github.com/NVIDIA/Megatron-LM.git git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM cd Megatron-LM
......
#!/usr/bin/python3
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import socket
import argparse
import warnings
from functools import partial
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
# Silence noisy DeprecationWarning/FutureWarning/UserWarning chatter emitted
# during torch / Transformer Engine imports so test output stays readable.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
def _te_layer_argtype(name):
    """argparse ``type`` converter mapping a case-insensitive layer name to its TE class.

    Args:
        name: Layer name as given on the command line (e.g. "linear", "TransformerLayer").

    Returns:
        The matching Transformer Engine layer class.

    Raises:
        argparse.ArgumentTypeError: if ``name`` does not match any supported layer.
    """
    te_layers = [
        te.Linear,
        te.LayerNormLinear,
        te.LayerNormMLP,
        te.MultiheadAttention,
        te.TransformerLayer,
    ]
    layer_map = {layer.__name__.lower(): layer for layer in te_layers}
    key = name.lower()
    if key not in layer_map:
        # Join the names so the error shows "linear, layernormlinear, ..." instead
        # of the unreadable dict_keys([...]) repr.
        raise argparse.ArgumentTypeError(
            f"Invalid TE layer name! Please choose from: {', '.join(layer_map.keys())}"
        )
    return layer_map[key]
def _get_layer_args(config, tp_group, tp_size, reference=False):
    """Assemble constructor arguments and the input shape for the selected TE layer.

    Args:
        config: Parsed options carrying ``layer_type``, ``num_heads``, ``head_dim``,
            ``seq_length`` and ``batch_size``.
        tp_group: Tensor-parallel process group passed to the layer.
        tp_size: Tensor-parallel world size.
        reference: When True, disable every Userbuffers comm+GEMM overlap option
            so the layer can serve as a numerics baseline.

    Returns:
        A ``(args, kwargs, input_shape)`` tuple for ``config.layer_type(*args, **kwargs)``.
    """
    overlap = not reference
    hidden = config.num_heads * config.head_dim
    shape = [config.seq_length, config.batch_size, hidden]
    positional = [hidden]
    keyword = {
        "params_dtype": torch.float32,
        "device": "cuda",
        "tp_group": tp_group,
        "tp_size": tp_size,
        "sequence_parallel": True,
        "ub_overlap_ag": overlap,
    }
    if config.layer_type is te.Linear:
        # Row-parallel projection: the hidden dimension of the input is sharded.
        shape[2] = hidden // tp_size
        positional.append(hidden)
        keyword.update(parallel_mode="row", ub_overlap_rs=overlap, ub_name="proj")
    else:
        # Sequence-parallel layers: the sequence dimension of the input is sharded.
        shape[0] = config.seq_length // tp_size
        keyword.update(ub_bulk_wgrad=overlap, ub_bulk_dgrad=overlap)
        if config.layer_type is te.LayerNormLinear:
            positional.append(3 * hidden)
            keyword.update(parallel_mode="column", ub_name="qkv")
        else:
            keyword.update(set_parallel_mode=True, ub_overlap_rs=overlap)
            if config.layer_type in (te.LayerNormMLP, te.TransformerLayer):
                positional.append(4 * hidden)
                keyword["seq_length"] = config.seq_length
            if config.layer_type in (te.MultiheadAttention, te.TransformerLayer):
                positional.append(config.num_heads)
                keyword.update(attention_dropout=0.0, fuse_qkv_params=True)
                if config.layer_type is te.MultiheadAttention:
                    keyword["input_layernorm"] = True
                else:
                    keyword.update(ub_tp_comm_overlap=overlap, hidden_dropout=0.0)
    return positional, keyword, shape
def _parse_args(argv=None, namespace=None):
    """Parse command-line options for the layer-level comm+GEMM overlap test.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``).
        namespace: Optional pre-populated argparse namespace.

    Returns:
        The parsed options, with ``use_cuda_graphs`` force-disabled for
        attention-bearing layers that do not support graph capture.
    """
    parser = argparse.ArgumentParser(
        description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers."
    )
    parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP)
    parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
    parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
    parser.add_argument("-n", "--num-heads", type=int, default=12, help="Number of attention heads.")
    parser.add_argument("-d", "--head-dim", type=int, default=64, help="Dimension of each attention head.")
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    parser.add_argument(
        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
    )
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--tcp-init",
        action="store_true",
        default=False,
        help="Initialize torch.distributed with TcpStore.",
    )
    parser.add_argument(
        "--bind-to-device",
        action="store_true",
        default=False,
        help="Initialize torch.distributed with `device_id` to bind each rank to a single device.",
    )
    parser.add_argument(
        "--bootstrap-backend",
        type=str.lower,
        default="nccl",
        choices=["gloo", "mpi", "nccl"],
        help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
    )
    parser.add_argument(
        "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Print out additional debug information.",
    )
    parsed = parser.parse_args(argv, namespace)
    # Attention-bearing layers cannot be captured into CUDA Graphs; warn and fall
    # back to eager execution rather than failing later.
    if parsed.use_cuda_graphs and parsed.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
        warnings.warn(f"{parsed.layer_type.__name__} does not support CUDA Graphs!")
        parsed.use_cuda_graphs = False
    return parsed
def _compare_tensors(name, test, ref, rtol, atol):
# Make sure tensors aren't zero and we don't pass trivially
if test.count_nonzero() == 0:
if ref.count_nonzero() == 0:
warnings.warn(
f"WARNING: {name} is a zero-tensor for both test and reference models!",
category=RuntimeWarning,
)
else:
numerics_info = (
f"NUMERICAL CHECK FAILED: {name} is a zero-tensor but does not match reference!"
)
return 1, numerics_info
diff = torch.abs(test - ref).flatten()
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5)
numerics_failed = 0
if rel_err > rtol and abs_err > atol:
numerics_failed = 1
numerics_info = (
"NUMERICAL CHECK FAILED: "
+ f"{name} not close enough at index {m.item()} "
+ f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | "
+ f"rel. error = {rel_err} (tol = {rtol}) | "
+ f"abs. error = {abs_err} (tol = {atol})"
)
else:
numerics_info = f"NUMERICAL CHECK PASSED: {name} | "
if rel_err <= rtol:
numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
" | " if abs_err <= atol else "."
)
if abs_err <= atol:
numerics_info += f" abs. error = {abs_err} (tol = {atol})"
return numerics_failed, numerics_info
def _train(opts):
    """Run a fwd/bwd numerics comparison between a TE layer with Userbuffers
    comm+GEMM overlap enabled (test model) and an identical layer with overlap
    disabled (reference model), on a single node.

    Args:
        opts: Parsed command-line options from ``_parse_args``.

    Returns:
        0 if the output and all compared gradients match within tolerance on
        every rank, 1 otherwise (used as the process exit code).

    Raises:
        RuntimeError: if not launched via ``mpirun`` or ``torchrun``.
    """
    # Discover distributed launch info from either OpenMPI or torchelastic env vars.
    if "OMPI_COMM_WORLD_SIZE" in os.environ:
        # Execution with `mpirun -np N`
        WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
        WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
        LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
        LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
        opts.tcp_init = True
        opts.bind_to_device = True
        opts.bootstrap_backend = "mpi"
    elif "TORCHELASTIC_RUN_ID" in os.environ:
        WORLD_RANK = int(os.getenv("RANK", "0"))
        WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
        LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    else:
        raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
    # This test is single-node only: every rank must be on the local node.
    assert LOCAL_SIZE == WORLD_SIZE

    def dist_print(msg, src=None, end="\n", debug=False, error=False):
        # Print from one rank (default rank 0) with a trailing barrier so
        # messages from different phases don't interleave across ranks.
        if debug and not opts.debug:
            return
        stream = sys.stderr if error else sys.stdout
        if WORLD_RANK == (0 if src is None else src):
            # NOTE(review): `end` already defaults to "\n" and another "\n" is
            # appended here, so each message is followed by a blank line —
            # confirm this double newline is intended.
            stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n")
        dist.barrier()

    # Set device and initialize RNG states
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed(opts.seed)

    # Initialize torch.distributed global process group and get DP/TP groups
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    if opts.tcp_init:
        MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
        MASTER_PORT = os.getenv("MASTER_PORT", "1234")
        dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
    if opts.bind_to_device or opts.bootstrap_backend == "nccl":
        # Bind each rank to a single device so collectives don't need lazy
        # device discovery.
        dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

    # Initialize userbuffers
    te.module.base.initialize_ub(
        [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
        WORLD_SIZE,
        use_fp8=opts.fp8,
        dtype=torch.bfloat16,
        bootstrap_backend=opts.bootstrap_backend,
    )

    # Initialize the Transformer Engine layer with overlap
    args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE)
    with te.fp8_model_init(enabled=opts.fp8_init):
        test_model = opts.layer_type(*args, **kwargs)
    dist_print("Initialized test model...", debug=True)

    # Initialize the reference model and copy all parameters
    ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True)
    with te.fp8_model_init(enabled=opts.fp8_init):
        ref_model = opts.layer_type(*ref_args, **ref_kwargs)
    dist_print("Initialized reference model...", debug=True)
    for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
        with torch.no_grad():
            ref_param.copy_(test_param)
        # Exact (zero-tolerance) check that the copy succeeded.
        torch.testing.assert_close(test_param, ref_param, rtol=0.0, atol=0.0)
    dist_print("Copied parameters from test model to reference model...", debug=True)

    # Fp8 recipe setup
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    # Prepare random input tensors
    test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
    test_x.retain_grad()
    ref_x = torch.empty_like(test_x).requires_grad_(True)
    with torch.no_grad():
        ref_x.copy_(test_x)
    torch.testing.assert_close(test_x, ref_x, rtol=0.0, atol=0.0)
    ref_x.retain_grad()

    # Execute fwd/bwd and collect tensors to test
    def run_fwd_bwd(model, x):
        # Forward under bf16 autocast (and optionally fp8_autocast), then
        # backward from a scalar sum so every output element contributes.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
                y = model(x)
                # Some layers return (output, extras...); keep only the output.
                if isinstance(y, tuple):
                    out, *_ = y
                else:
                    out = y
        loss = out.sum()
        loss.backward()
        return out

    # Snapshot RNG states so test and reference runs see identical randomness.
    torch_rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}"))
    if opts.use_cuda_graphs:
        test_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(test_graph):
            test_out = run_fwd_bwd(test_model, test_x)
        test_graph.replay()
        del test_graph
    else:
        test_out = run_fwd_bwd(test_model, test_x)
    test_grads = [test_out, test_x.grad]
    names = ["output", "input.grad"]
    for test_name, test_param in test_model.named_parameters():
        # Layer-norm parameters are excluded from the numerics comparison.
        if test_param.requires_grad and "layer_norm" not in test_name:
            test_grads.append(test_param.grad)
            names.append(test_name + ".grad")

    # Restore RNG states so the reference run replays the same randomness.
    torch.set_rng_state(torch_rng_state)
    torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}"))
    if opts.use_cuda_graphs:
        ref_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(ref_graph):
            ref_out = run_fwd_bwd(ref_model, ref_x)
        ref_graph.replay()
        del ref_graph
    else:
        ref_out = run_fwd_bwd(ref_model, ref_x)
    ref_grads = [ref_out, ref_x.grad]
    for ref_name, ref_param in ref_model.named_parameters():
        if ref_param.requires_grad and "layer_norm" not in ref_name:
            ref_grads.append(ref_param.grad)

    # Make sure we have the same number of gradients
    numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
    if len(test_grads) != len(ref_grads):
        numerics_failed[0] = 1
        numerics_info = (
            "NUMERICAL CHECK FAILED: Incorrect number of gradients, "
            + f"expected {len(ref_grads)} but got {len(test_grads)}."
        )
        dist_print(numerics_info, src=WORLD_RANK, error=True)
    # Collective: every rank participates so a failure on any rank is seen by all.
    dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)

    # Now validate accuracy
    if not bool(numerics_failed.item()):
        for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
            # Looser tolerances under FP8 to account for reduced precision.
            rtol = 0.125 if opts.fp8 else 0.025
            atol = 0.0625 if opts.fp8 else 0.00125
            grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
            dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
            numerics_failed[0] = int(grad_failed)
            dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
            if bool(numerics_failed.item()):
                break

    # Tear down Userbuffers and distributed state before exiting.
    te.module.base.destroy_ub()
    dist_print("Destroying Userbuffers objects...", debug=True)
    dist_print("Destroying all process groups...", debug=True)
    dist.destroy_process_group()
    if opts.debug and WORLD_RANK == 0:
        print("Exiting...\n", end="", flush=True)
    return numerics_failed[0].item()
# Run the test and propagate the numerics result (0 = pass) as the exit code.
if __name__ == "__main__":
    sys.exit(_train(_parse_args()))
...@@ -7,16 +7,27 @@ from pathlib import Path ...@@ -7,16 +7,27 @@ from pathlib import Path
import pytest import pytest
import torch import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
RNG_SEED: int = 1234 RNG_SEED: int = 1234
SEQ_LENGTH: int = 2024 SEQ_LENGTH: int = 512
BATCH_SIZE: int = 2 BATCH_SIZE: int = 2
NUM_HEADS: int = 64 NUM_HEADS: int = 12
HEAD_DIM: int = 128 HEAD_DIM: int = 64
TE_LAYERS = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
TEST_ROOT = Path(__file__).parent.resolve() TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(torch.cuda.device_count(), 4) NUM_PROCS: int = min(torch.cuda.device_count(), 4)
...@@ -32,66 +43,28 @@ if not tex.device_supports_multicast(): ...@@ -32,66 +43,28 @@ if not tex.device_supports_multicast():
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.") def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate):
@pytest.mark.parametrize(
"fp8,p2p,comm_type,aggregate,atomic,bulk",
[
# FP8, P2P, Type, Aggregate, Atomic, Bulk
(False, True, "AG", False, False, False),
(False, True, "AG", True, False, False),
(True, True, "AG", False, False, False),
(True, True, "AG", True, False, False),
(False, False, "RS", False, False, False),
(False, True, "RS", False, False, False),
(True, False, "RS", False, False, False),
(True, True, "RS", False, False, False),
(True, False, "RS", False, True, False),
(True, True, "RS", False, True, False),
(False, False, "AG", False, False, True),
(False, False, "RS", False, False, True),
],
ids=[
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ",
" SPLIT GEMM -> RS | BF16 | PIPELINE ",
" SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ",
" SPLIT GEMM -> RS | FP8 | PIPELINE ",
" SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ",
" ATOMIC GEMM -> RS | FP8 | PIPELINE ",
" ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ",
" BULK AG & GEMM | BF16 | PIPELINE ",
" BULK RS & GEMM | BF16 | PIPELINE ",
],
)
def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
"""
Test comm+GEMM overlap algorithms with direct calls to
te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm
"""
test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = ( test_cmd = LAUNCH_CMD + [
LAUNCH_CMD str(test_path),
+ [str(test_path)] "--check-numerics",
+ [ f"--seed={RNG_SEED}",
"--check-numerics", f"--seq-length={SEQ_LENGTH}",
f"--seed={RNG_SEED}", f"--batch-size={BATCH_SIZE}",
f"--seq-length={SEQ_LENGTH}", f"--num-heads={NUM_HEADS}",
f"--batch-size={BATCH_SIZE}", f"--head-dim={HEAD_DIM}",
f"--num-heads={NUM_HEADS}", f"--comm-type={comm_type}",
f"--head-dim={HEAD_DIM}", ]
f"--comm-type={comm_type}",
]
)
if bulk: if bulk:
test_cmd.append("--bulk-overlap") test_cmd.append("--bulk-overlap")
else: else:
if fp8: if fp8_in:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8") test_cmd.append("--fp8")
if fp8_out:
test_cmd.append("--fp8-output")
if p2p: if p2p:
test_cmd.append("--p2p") test_cmd.append("--p2p")
if aggregate: if aggregate:
...@@ -101,5 +74,173 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk): ...@@ -101,5 +74,173 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.") pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.")
test_cmd.append("--atomic") test_cmd.append("--atomic")
output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False) result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
assert "NUMERICAL CHECK PASSED" in str(output) if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(result.stderr.decode())
def _run_layer_with_overlap(layer_type, fp8, fp8_init):
    """Launch a distributed subprocess test of one TE layer with comm+GEMM overlap.

    Builds the launcher command line for ``run_layer_with_overlap.py``, pins
    environment variables that force deterministic execution, runs the test
    and asserts on the subprocess output.

    Parameters
    ----------
    layer_type : str
        Name of the Transformer Engine layer class to test.
    fp8 : bool
        Run GEMMs in FP8 (skipped when FP8 is unavailable on this device).
    fp8_init : bool
        Also initialize layer parameters in FP8 (only meaningful with ``fp8``).

    Raises
    ------
    AssertionError
        If the subprocess exits non-zero, reports a failed numerical check,
        or never reports a passed numerical check.
    """
    test_path = TEST_ROOT / "run_layer_with_overlap.py"
    test_cmd = LAUNCH_CMD + [
        str(test_path),
        f"--seed={RNG_SEED}",
        f"--seq-length={SEQ_LENGTH}",
        f"--batch-size={BATCH_SIZE}",
        f"--num-heads={NUM_HEADS}",
        f"--head-dim={HEAD_DIM}",
        f"--layer-type={layer_type}",
    ]

    if fp8:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        test_cmd.append("--fp8")
        if fp8_init:
            test_cmd.append("--fp8-init")

    # Force deterministic execution so numerics are comparable across ranks.
    os.environ["PYTORCH_JIT"] = "0"
    os.environ["NVTE_TORCH_COMPILE"] = "0"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
    try:
        result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
    finally:
        # BUG FIX: the original called os.unsetenv(), which does NOT update the
        # os.environ mapping (and is not available on every platform), so these
        # variables kept leaking into every later subprocess.run(env=os.environ)
        # call. os.environ.pop() removes them from the mapping and the process
        # environment; the finally-block also guarantees cleanup on exceptions.
        os.environ.pop("PYTORCH_JIT", None)
        os.environ.pop("NVTE_TORCH_COMPILE", None)
        os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)

    if (
        result.returncode != 0
        or "NUMERICAL CHECK FAILED" in result.stderr.decode()
        or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
    ):
        raise AssertionError(result.stderr.decode())
@pytest.mark.parametrize(
    "fp8,aggregate",
    # All four (fp8, aggregate) combinations, in the same order as the ids.
    [(fp8_in, agg) for fp8_in in (False, True) for agg in (False, True)],
    ids=[
        " BF16 IN - RING-EXCHANGE ",
        " BF16 IN - RING-EXCHANGE - 2x AGGREGATED ",
        " FP8 IN - RING-EXCHANGE ",
        " FP8 IN - RING-EXCHANGE - 2x AGGREGATED ",
    ],
)
def test_split_all_gather_overlaps(fp8, aggregate):
    """Exercise the (split GEMM -> all-gather) overlap path through direct
    calls to te.cpp_extensions.gemm / te.cpp_extensions.fp8_gemm."""
    _run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate)
@pytest.mark.parametrize(
    "fp8_in,fp8_out,p2p",
    # Each supported (fp8_in, fp8_out) pairing is run with and without the
    # point-to-point (ring-exchange) backend; order matches the ids below.
    [
        (fi, fo, ring)
        for fi, fo in ((False, False), (True, False), (True, True))
        for ring in (False, True)
    ],
    ids=[
        " BF16 IN - BF16 OUT - PIPELINE ",
        " BF16 IN - BF16 OUT - RING-EXCHANGE ",
        " FP8 IN - BF16 OUT - PIPELINE ",
        " FP8 IN - BF16 OUT - RING-EXCHANGE ",
        " FP8 IN - FP8 OUT - PIPELINE ",
        " FP8 IN - FP8 OUT - RING-EXCHANGE ",
    ],
)
def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p):
    """Exercise the (reduce-scatter -> split GEMM) overlap path through direct
    calls to te.cpp_extensions.gemm / te.cpp_extensions.fp8_gemm."""
    _run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False)
@pytest.mark.parametrize(
    "ag_type,rs_type,p2p,fp8_out",
    [
        (0, 0, False, False),
        (0, 1, False, False),
        (0, 1, False, True),
        (0, 2, False, False),
        (0, 2, False, True),
        (0, 0, True, False),
        (0, 0, True, True),
        (1, 0, True, False),
        (1, 0, True, True),
    ],
    ids=[
        " NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ",
        " NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ",
        " NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ",
        " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ",
        " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ",
        " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ",
        " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ",
        " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ",
        " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ",
    ],
)
def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):
    """Test paired (all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter)
    overlaps with direct calls to te.cpp_extensions.gemm or
    te.cpp_extensions.fp8_gemm.

    The atomic-GEMM variants are selected via environment variables consumed
    by the launched subprocess (NVTE_AG_P2P_MULTI_ATOMIC / NVTE_RS_STRIDED_ATOMIC).
    """
    os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type)
    os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type)
    try:
        _run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False)
    finally:
        # BUG FIX: the original never removed these variables, so they leaked
        # into the environment of every subsequently launched test subprocess
        # (e.g. test_bulk_overlaps), silently changing its configuration.
        os.environ.pop("NVTE_AG_P2P_MULTI_ATOMIC", None)
        os.environ.pop("NVTE_RS_STRIDED_ATOMIC", None)
@pytest.mark.parametrize(
    "comm_type,fp8",
    [
        ("AG", False),
        ("RS", False),
        ("RS", True),
    ],
    ids=[
        " ALL-GATHER - BF16 ",
        " REDUCE-SCATTER - BF16 ",
        " REDUCE-SCATTER - FP8 ",
    ],
)
def test_bulk_overlaps(comm_type, fp8):
    """Exercise bulk comm+GEMM overlaps through direct calls to
    te.cpp_extensions.gemm / te.cpp_extensions.fp8_gemm."""
    _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
@pytest.mark.parametrize(
    "layer_type",
    [layer.__name__ for layer in TE_LAYERS],
    ids=[f" {layer.__name__} " for layer in TE_LAYERS],
)
@pytest.mark.parametrize(
    "fp8,fp8_init",
    [
        (False, False),
        (True, False),
        (True, True),
    ],
    ids=[
        " BF16 GEMM - BF16 PARAMS ",
        " FP8 GEMM - BF16 PARAMS ",
        " FP8 GEMM - FP8 PARAMS ",
    ],
)
def test_layers_with_overlap(layer_type, fp8, fp8_init):
    """Run each Transformer Engine layer end-to-end with comm+GEMM overlap
    enabled, across BF16/FP8 GEMM and parameter-initialization modes."""
    _run_layer_with_overlap(layer_type, fp8, fp8_init)
...@@ -166,7 +166,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -166,7 +166,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Initialize userbuf communicator // Initialize userbuf communicator
if (!comm_created) { if (!comm_created) {
if (myrank == 0) { if (myrank == 0) {
printf("!!! [UB] Create UbufCommOverlap Communicator\n"); printf("!!! [UB] Create Userbuffers Communicator\n");
} }
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
...@@ -184,16 +184,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -184,16 +184,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Allocate and register extra userbuffers // Allocate and register extra userbuffers
int ubuf_bytes = sample.numel() * sample.element_size(); int ubuf_bytes = sample.numel() * sample.element_size();
if (transformer_engine::getenv<bool>("UB_SKIPMC")) { _ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ubuf = torch::zeros_like(sample); _ub_comm, true);
_ubuf_ptr = _ubuf.data_ptr(); _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, false);
} else {
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
}
if (_ub_comm->myrank == 0) { if (_ub_comm->myrank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg); printf("!!! [UB] Register UBuf %d\n", _ub_reg);
...@@ -264,6 +257,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -264,6 +257,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type,
at::Tensor rs_output) { at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -319,6 +313,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -319,6 +313,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1); int output_c_dim1 = _ubuf.size(1);
output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
_ub_comm->sms = ori_sms;
return {D, output_tensor}; return {D, output_tensor};
} // bulk_overlap } // bulk_overlap
...@@ -336,6 +331,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -336,6 +331,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap, bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) { at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -352,7 +348,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -352,7 +348,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr()); char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr()); int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
...@@ -388,7 +383,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -388,7 +383,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>( reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m,
_num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);); _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm););
...@@ -402,7 +397,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -402,7 +397,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>( reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);); counter_ptr, _ub_comm, (cudaStream_t)_stream_comm););
...@@ -413,10 +408,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -413,10 +408,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
} }
break; break;
} else { } else {
assert(_ubuf.element_size() != 1);
consumer(counter_ptr, i, (cudaStream_t)_stream_comm); consumer(counter_ptr, i, (cudaStream_t)_stream_comm);
// if (i == _num_splits-1) {
// _ub_comm->sms = UB_MAX_SM;
// }
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm); _ub_comm, (cudaStream_t)_stream_comm);
} }
...@@ -447,6 +440,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -447,6 +440,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
size_t workspaceSize, bool accumulate, bool use_split_accumulator, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output) { bool gemm_overlap, at::Tensor rs_output) {
// Get GEMM dimensions // Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -464,7 +458,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -464,7 +458,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr()); char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
...@@ -517,7 +510,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -517,7 +510,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm);); m, _ub_comm, (cudaStream_t)_stream_comm););
...@@ -541,7 +534,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -541,7 +534,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm););
...@@ -577,7 +570,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -577,7 +570,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);); _ub_comm, (cudaStream_t)_stream_comm););
...@@ -682,7 +675,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -682,7 +675,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Initialize userbuf communicator // Initialize userbuf communicator
if (!comm_created) { if (!comm_created) {
if (myrank == 0) { if (myrank == 0) {
printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n"); printf("!!! [UB] Create Userbuffers Communicator\n");
} }
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
...@@ -708,19 +701,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -708,19 +701,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1)); ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1));
num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1); num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1);
} }
if (transformer_engine::getenv<bool>("UB_SKIPMC")) {
_ubuf = torch::zeros({sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, _ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
sample.options()); _ub_comm, true);
_ubuf_ptr = _ubuf.data_ptr(); _ubuf = torch::from_blob(
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes, _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options());
_ub_comm, false);
} else {
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf =
torch::from_blob(_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)},
sample.options());
}
if (_ub_comm->myrank == 0) { if (_ub_comm->myrank == 0) {
printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
} }
...@@ -728,9 +713,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -728,9 +713,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Create tensor chunks for easy management // Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr()); char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
for (int i = 0; i < num_ubuf_chunks; i++) { for (int i = 0; i < num_ubuf_chunks; i++) {
torch::Tensor ubuf_chunk = torch::from_blob( auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)},
ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options()); sample.options());
_ubufs.push_back(ubuf_chunk); _ubufs.push_back(std::move(ubuf_chunk));
ubuf_byte_ptr += ubuf_chunk_bytes; ubuf_byte_ptr += ubuf_chunk_bytes;
} }
...@@ -769,6 +754,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -769,6 +754,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (_rank == 0 && env_p != nullptr) { if (_rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') { if (env_p[0] == '1') {
_use_ce = 0;
_ub_comm->push = 1;
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
} }
} }
...@@ -818,6 +805,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -818,6 +805,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -866,6 +854,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -866,6 +854,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') { if (env_p != nullptr && env_p[0] == '1') {
if (i == 0) { if (i == 0) {
_ub_comm->use_ce = 0;
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
true, (cudaStream_t)_stream_recv); true, (cudaStream_t)_stream_recv);
...@@ -906,6 +895,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -906,6 +895,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main)); (cudaStream_t)stream_main));
// Return the last N rows of D_buffer // Return the last N rows of D_buffer
_ub_comm->sms = ori_sms;
torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n);
return D_return; return D_return;
} // atomic_gemm_overlap_ag } // atomic_gemm_overlap_ag
...@@ -926,6 +916,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -926,6 +916,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -1078,6 +1069,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1078,6 +1069,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
at::cuda::setCurrentCUDAStream(stream_main); at::cuda::setCurrentCUDAStream(stream_main);
_ub_comm->sms = ori_sms;
return D; return D;
} // split_overlap_ag } // split_overlap_ag
...@@ -1094,6 +1086,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1094,6 +1086,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { bool accumulate, bool use_split_accumulator, at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -1149,7 +1142,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1149,7 +1142,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main);); _ubufs[0].numel(), (cudaStream_t)stream_main););
} else { } else {
...@@ -1157,6 +1150,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1157,6 +1150,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0); torch::sum_out(rs_output, reduce_buf, 0);
} }
_ub_comm->sms = ori_sms;
} }
/* /*
...@@ -1171,6 +1165,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1171,6 +1165,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output) { at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -1245,7 +1240,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1245,7 +1240,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main);); _ubufs[0].numel(), (cudaStream_t)stream_main););
} else { } else {
...@@ -1259,6 +1254,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1259,6 +1254,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
} }
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
_ub_comm->sms = ori_sms;
} }
/* /*
......
...@@ -1861,6 +1861,14 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const ...@@ -1861,6 +1861,14 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} }
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
template <typename fp8type> template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream) { const int elements, communicator *comm, cudaStream_t stream) {
......
...@@ -107,7 +107,7 @@ def initialize_ub( ...@@ -107,7 +107,7 @@ def initialize_ub(
world_size = torch.distributed.get_world_size(mpi_group) world_size = torch.distributed.get_world_size(mpi_group)
local_rank = world_rank % tp_size local_rank = world_rank % tp_size
local_size = tp_size local_size = tp_size
node_id = world_rank // tp_size self_node_idx = world_rank // tp_size
num_nodes = world_size // tp_size num_nodes = world_size // tp_size
ub_callbacks = tex.UbufBootstrapCallbacks() ub_callbacks = tex.UbufBootstrapCallbacks()
else: else:
...@@ -127,13 +127,6 @@ def initialize_ub( ...@@ -127,13 +127,6 @@ def initialize_ub(
world_rank = torch.distributed.get_rank(world_group) world_rank = torch.distributed.get_rank(world_group)
world_size = torch.distributed.get_world_size(world_group) world_size = torch.distributed.get_world_size(world_group)
if world_rank == 0:
print(
f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n',
end="",
flush=True,
)
# Construct an intra-node communicator based on global ranks that share the same hostname # Construct an intra-node communicator based on global ranks that share the same hostname
# NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host
# address on that interface instead of the hostname. This can help avoid issues when # address on that interface instead of the hostname. This can help avoid issues when
...@@ -157,28 +150,41 @@ def initialize_ub( ...@@ -157,28 +150,41 @@ def initialize_ub(
hostnames = [None for _ in range(world_size)] hostnames = [None for _ in range(world_size)]
torch.distributed.all_gather_object(hostnames, hostname, world_group) torch.distributed.all_gather_object(hostnames, hostname, world_group)
intra_node_ranks = [] unique_hosts = []
for i, host in enumerate(hostnames): for host in hostnames:
if host == hostname: if host not in unique_hosts:
intra_node_ranks.append(i) unique_hosts.append(host)
if len(intra_node_ranks) == world_size: num_nodes = len(unique_hosts)
if num_nodes > 1:
ranks_per_node_list = [[] for _ in range(num_nodes)]
self_node_idx = -1
for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname:
self_node_idx = node_idx
assert self_node_idx >= 0, "Internal TE error!"
intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration(
ranks_per_node_list, backend=bootstrap_backend
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group)
else:
self_node_idx = 0
intra_node_group = world_group intra_node_group = world_group
local_rank = world_rank local_rank = world_rank
local_size = world_size local_size = world_size
intra_node_ranks = list(range(world_size)) intra_node_ranks = list(range(world_size))
else:
intra_node_group = torch.distributed.new_group(
backend=bootstrap_backend, ranks=intra_node_ranks
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
node_id = world_rank // local_size if world_rank == 0:
num_nodes = world_size // local_size print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True)
if local_rank == 0: if local_rank == 0:
print( print(
f"!!! [NVTE] Number of physical nodes: {num_nodes}\n" f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n",
+ f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n",
end="", end="",
flush=True, flush=True,
) )
...@@ -293,7 +299,7 @@ def initialize_ub( ...@@ -293,7 +299,7 @@ def initialize_ub(
world_size, # World size world_size, # World size
local_rank, # Rank within the node local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node local_size, # Number of ranks/GPUs per node
node_id, # Node ID self_node_idx, # Node ID
num_nodes, # Number of nodes num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size) tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs num_sm, # Number of communication SMs
...@@ -313,7 +319,7 @@ def initialize_ub( ...@@ -313,7 +319,7 @@ def initialize_ub(
world_size, # World size world_size, # World size
local_rank, # Rank within the node local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node local_size, # Number of ranks/GPUs per node
node_id, # Node ID self_node_idx, # Node ID
num_nodes, # Number of nodes num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size) tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs num_sm, # Number of communication SMs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment