Unverified Commit fa4b866d authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[C/PyTorch] Fixed incorrect use of `torch.distributed.new_group()` when...


[C/PyTorch] Fixed incorrect use of `torch.distributed.new_group()` when creating intra-node group in `initialize_ub()` (#1087)

* updated initialize_ub() to use new_subgroups_by_enumeration() to generate intra-node groups, added new unit tests for TE layers with comm overlap
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 86f27e12
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
import os import os
import sys import sys
import socket import socket
import fcntl
import struct
import argparse import argparse
import warnings import warnings
...@@ -15,15 +17,37 @@ import torch.distributed as dist ...@@ -15,15 +17,37 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.common.recipe import Format, DelayedScaling from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"
def _te_layer_argtype(name):
    """Map a case-insensitive layer name to the matching TE layer class.

    Intended for use as an ``argparse`` ``type=`` callback: raises
    ``argparse.ArgumentTypeError`` (which argparse converts into a clean
    usage error) when the name does not match a supported layer.

    Args:
        name: Layer name supplied on the command line (any casing).

    Returns:
        The Transformer Engine layer class corresponding to ``name``.

    Raises:
        argparse.ArgumentTypeError: If ``name`` is not a supported layer.
    """
    te_layers = [
        te.Linear,
        te.LayerNormLinear,
        te.LayerNormMLP,
        te.MultiheadAttention,
        te.TransformerLayer,
    ]
    layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], layer_map_values := te_layers))
    if name.lower() not in layer_map:
        # list(...) so the error shows a plain name list instead of the
        # noisy `dict_keys([...])` repr.
        raise argparse.ArgumentTypeError(
            f"Invalid TE layer name! Please choose from: {list(layer_map.keys())}"
        )
    return layer_map[name.lower()]
def _parse_args(argv=None, namespace=None): def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers." description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers."
) )
parser.add_argument( parser.add_argument(
"-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations." "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
...@@ -37,10 +61,10 @@ def _parse_args(argv=None, namespace=None): ...@@ -37,10 +61,10 @@ def _parse_args(argv=None, namespace=None):
"-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
) )
parser.add_argument( parser.add_argument(
"--mlp-expansion-factor", "--layer-type",
type=int, type=_te_layer_argtype,
default=4, default=te.TransformerLayer,
help="MLP block intermediate size as a factor of hidden dimension.", help="Transformer Engine layer to train with comm+GEMM overlap.",
) )
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument( parser.add_argument(
...@@ -88,9 +112,57 @@ def _parse_args(argv=None, namespace=None): ...@@ -88,9 +112,57 @@ def _parse_args(argv=None, namespace=None):
help="Print out additional debug information.", help="Print out additional debug information.",
) )
args = parser.parse_args(argv, namespace) args = parser.parse_args(argv, namespace)
if args.bootstrap_backend == "nccl":
args.bind_to_device = True
return args return args
def _get_layer_args(config, tp_group, tp_size, reference=False):
    """Build constructor args/kwargs and the input shape for the selected TE layer.

    Args:
        config: Parsed CLI options; reads ``num_heads``, ``head_dim``,
            ``seq_length``, ``batch_size``, ``layer_type`` and
            ``no_comm_overlap``.
        tp_group: Tensor-parallel process group passed through to the layer.
        tp_size: Tensor-parallel group size.
        reference: Unused here; presumably reserved for building a
            non-parallel reference layer — TODO confirm against callers.

    Returns:
        Tuple of (positional args list, keyword args dict, input tensor shape
        as ``[seq, batch, hidden]`` adjusted for the layer's parallelism).
    """
    hidden_size = config.num_heads * config.head_dim
    input_shape = [config.seq_length, config.batch_size, hidden_size]
    # First positional arg for every supported layer is the hidden size.
    args = [hidden_size]
    kwargs = {
        "params_dtype": torch.float32,
        "device": "cuda",
        "tp_group": tp_group,
        "tp_size": tp_size,
        "sequence_parallel": True,
    }
    kwargs["ub_overlap_ag"] = not config.no_comm_overlap

    if config.layer_type is te.Linear:
        # Row-parallel projection: input hidden dim is sharded across TP ranks.
        input_shape[2] = hidden_size // tp_size
        args.append(hidden_size)
        kwargs["parallel_mode"] = "row"
        kwargs["ub_overlap_rs"] = not config.no_comm_overlap
        kwargs["ub_name"] = "proj"
    else:
        # Sequence-parallel layers: sequence dim is sharded across TP ranks.
        input_shape[0] = config.seq_length // tp_size
        kwargs["ub_bulk_wgrad"] = not config.no_comm_overlap
        kwargs["ub_bulk_dgrad"] = not config.no_comm_overlap
        if config.layer_type is te.LayerNormLinear:
            # Fused QKV projection: 3x hidden output, column-parallel.
            args.append(3 * hidden_size)
            kwargs["parallel_mode"] = "column"
            kwargs["ub_name"] = "qkv"
        else:
            kwargs["set_parallel_mode"] = True
            kwargs["ub_overlap_rs"] = not config.no_comm_overlap
            if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
                # ffn_hidden_size = 4x hidden (standard MLP expansion).
                args.append(4 * hidden_size)
                kwargs["seq_length"] = config.seq_length
            if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
                args.append(config.num_heads)
                # Dropout disabled so runs are deterministic for comparison.
                kwargs["attention_dropout"] = 0.0
                kwargs["fuse_qkv_params"] = True
                if config.layer_type is te.MultiheadAttention:
                    kwargs["input_layernorm"] = True
                else:
                    kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap
                    kwargs["hidden_dropout"] = 0.0

    return args, kwargs, input_shape
def _train(opts): def _train(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ: if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N` # Execution with `mpirun -np N`
...@@ -110,19 +182,6 @@ def _train(opts): ...@@ -110,19 +182,6 @@ def _train(opts):
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
NUM_NODES = WORLD_SIZE // LOCAL_SIZE NUM_NODES = WORLD_SIZE // LOCAL_SIZE
def dist_print(msg, group=None, end="\n", debug=False):
if debug and not opts.debug:
return
group = dist.new_group() if group is None else group
group_rank = dist.get_rank(group)
group_size = dist.get_world_size(group)
all_ranks = dist.get_process_group_ranks(group)
ranks_skip = all_ranks[1] - all_ranks[0] > 1
group_id = WORLD_RANK % group_size if ranks_skip else WORLD_RANK // group_size
if group_rank == 0 or opts.verbose:
print(f"[rank{WORLD_RANK}:node{group_id}] {msg}{end}", end="", flush=True)
dist.barrier(group)
# Initialize torch.distributed global process group and get DP/TP groups # Initialize torch.distributed global process group and get DP/TP groups
torch.cuda.set_device(LOCAL_RANK) torch.cuda.set_device(LOCAL_RANK)
dist_init_kwargs = { dist_init_kwargs = {
...@@ -143,75 +202,117 @@ def _train(opts): ...@@ -143,75 +202,117 @@ def _train(opts):
assert dist.is_nccl_available() assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs) dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl") nccl_world = dist.new_group(backend="nccl")
dist_print(f"Initialized default NCCL process group with {WORLD_RANK} GPUs", nccl_world)
def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False):
    # Print `msg` exactly once per process group (from rank `src`), then
    # barrier on the group so output from successive calls stays ordered.
    # Closure over `opts`, `WORLD_RANK`, and `nccl_world` from _train().
    #
    # debug=True messages are suppressed unless --debug was passed;
    # error=True routes the message to stderr instead of stdout.
    if debug and not opts.debug:
        return
    group_rank = dist.get_rank(group)
    stream = sys.stderr if error else sys.stdout
    if group_rank == src:
        stream.write(f"[rank{WORLD_RANK}] {msg}{end}")
    dist.barrier(group)
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Figure out process groups for tensor- and data-parallelism (if any) # Figure out process groups for tensor- and data-parallelism (if any)
if NUM_NODES > 1: if NUM_NODES > 1:
# Create a list of world ranks on this node # Create a list of world ranks on this node
hostnames = [None for _ in range(WORLD_SIZE)]
hostname = socket.gethostname() hostname = socket.gethostname()
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
)
if ifname is not None:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err
hostnames = [None for _ in range(WORLD_SIZE)]
dist.all_gather_object(hostnames, hostname) dist.all_gather_object(hostnames, hostname)
node_ranks = [] unique_hosts = []
for host in hostnames:
if host not in unique_hosts:
unique_hosts.append(host)
assert len(unique_hosts) == NUM_NODES
ranks_per_node_list = [[] for _ in range(NUM_NODES)]
self_node_idx = -1
for i, host in enumerate(hostnames): for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname: if host == hostname:
node_ranks.append(i) self_node_idx = node_idx
assert self_node_idx >= 0
self_node_ranks = ranks_per_node_list[self_node_idx]
if opts.num_replicas > 1: if opts.num_replicas > 1:
# Split node ranks into multiple replicas # Split node ranks into multiple replicas
assert len(node_ranks) % opts.num_replicas == 0 assert len(self_node_ranks) % opts.num_replicas == 0
tp_size = len(node_ranks) // opts.num_replicas tp_size = len(self_node_ranks) // opts.num_replicas
found_replica = False ranks_per_replica_list = []
for replica in range(opts.num_replicas): for node_ranks in ranks_per_node_list:
start = replica * tp_size for i in range(opts.num_replicas):
start = i * tp_size
end = start + tp_size end = start + tp_size
tp_ranks = node_ranks[start:end] ranks_per_replica_list.append(node_ranks[start:end])
if WORLD_RANK in tp_ranks:
found_replica = True self_replica_idx = -1
for i, replica_ranks in enumerate(ranks_per_replica_list):
if WORLD_RANK in replica_ranks:
self_replica_idx = i
break break
assert found_replica assert self_replica_idx >= 0
else: else:
# The entire node is the tensor-parallel group # The entire node is the tensor-parallel group
tp_ranks = node_ranks ranks_per_replica_list = ranks_per_node_list
self_replica_idx = self_node_idx
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl")
tp_size = dist.get_world_size(tp_group) ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32)
tp_rank = dist.get_rank(tp_group) dp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
# Data-parallelism across TP groups )
dp_start = tp_rank
dp_end = dp_start + WORLD_SIZE
dp_ranks = list(range(dp_start, dp_end, tp_size))
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks)
else: else:
if opts.num_replicas > 1: if opts.num_replicas > 1:
# Mixed data- and tensor-parallelism on a single node # Mixed data- and tensor-parallelism on a single node
# NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
mesh2d = all_ranks.reshape((opts.num_replicas, LOCAL_SIZE // opts.num_replicas)) ranks_per_replica_tensor = all_ranks.reshape(
node_idx = (mesh2d == LOCAL_RANK).nonzero().squeeze().tolist() (opts.num_replicas, LOCAL_SIZE // opts.num_replicas)
)
tp_ranks = mesh2d[node_idx[0], :].tolist() tp_group, _ = dist.new_subgroups_by_enumeration(
tp_group = dist.new_group(backend="nccl", ranks=tp_ranks) ranks_per_replica_tensor.tolist(), backend="nccl"
)
dp_ranks = mesh2d[:, node_idx[1]].tolist() dp_group, _ = dist.new_subgroups_by_enumeration(
dp_group = dist.new_group(backend="nccl", ranks=dp_ranks) ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
)
else: else:
dp_group = None dp_group = None
tp_group = nccl_world tp_group = nccl_world
tp_rank = dist.get_rank(tp_group) tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group) tp_size = dist.get_world_size(tp_group)
dist_print( dist_print(
f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}",
group=tp_group, group=tp_group,
) )
if dp_group is not None: if dp_group is not None:
dp_rank = dist.get_rank(dp_group)
dist_print( dist_print(
f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}",
group=dp_group, group=dp_group,
) )
else:
dp_rank = 0
# Intialize userbuffers # Intialize userbuffers
hidden_size = opts.num_heads * opts.head_dim hidden_size = opts.num_heads * opts.head_dim
...@@ -226,26 +327,12 @@ def _train(opts): ...@@ -226,26 +327,12 @@ def _train(opts):
) )
# Initialize the fused LayerNorm + Multi-layer Perceptron module # Initialize the fused LayerNorm + Multi-layer Perceptron module
torch.manual_seed(opts.seed + tp_rank) torch.manual_seed(opts.seed + dp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank) torch.cuda.manual_seed(opts.seed + tp_rank)
model = te.LayerNormMLP( layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size)
hidden_size, model = opts.layer_type(*layer_args, **layer_kwargs)
opts.mlp_expansion_factor * hidden_size,
params_dtype=torch.bfloat16,
device="cuda",
tp_group=tp_group,
tp_size=tp_size,
set_parallel_mode=True,
sequence_parallel=True, # this is required for comm+GEMM overlap
seq_length=opts.seq_length,
ub_overlap_rs=not opts.no_comm_overlap,
ub_overlap_ag=not opts.no_comm_overlap,
ub_overlap_rs_dgrad=not opts.no_comm_overlap,
ub_bulk_dgrad=False,
ub_bulk_wgrad=not opts.no_comm_overlap,
)
if dp_group is not None: if dp_group is not None:
model = DistributedDataParallel(model, process_group=dp_group) model = DistributedDataParallel(model, dim=1, process_group=dp_group)
# Initialize optimizer with model parameters # Initialize optimizer with model parameters
optim = torch.optim.Adam(model.parameters(), lr=0.0001) optim = torch.optim.Adam(model.parameters(), lr=0.0001)
...@@ -255,28 +342,28 @@ def _train(opts): ...@@ -255,28 +342,28 @@ def _train(opts):
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations # Start dummy "training" iterations
dist_print("Starting training iterations...", nccl_world) dist_print("Starting training iterations...")
for i in range(opts.num_iters): for i in range(opts.num_iters):
dist_print(f" Iter {i+1}", tp_group, debug=True) dist_print(f" Iter {i+1}", group=tp_group, debug=True)
dist_print(" |-- Generate random input batch", tp_group, debug=True) dist_print(" |-- Generate random input batch", group=tp_group, debug=True)
x = torch.rand( x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True)
(opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
dist_print(" |-- Forward pass", tp_group, debug=True) dist_print(" |-- Forward pass", group=tp_group, debug=True)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x) y = model(x)
dist_print(" |-- Compute loss", tp_group, debug=True) if isinstance(y, tuple):
loss = y.flatten().sum() out, *_ = y
else:
out = y
dist_print(" |-- Compute loss", group=tp_group, debug=True)
loss = out.sum()
dist_print(" |-- Backward pass", tp_group, debug=True) dist_print(" |-- Backward pass", group=tp_group, debug=True)
loss.backward() loss.backward()
dist_print(" |-- Optimizer step", tp_group, debug=True) dist_print(" |-- Optimizer step", group=tp_group, debug=True)
optim.step() optim.step()
torch.cuda.synchronize() torch.cuda.synchronize()
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
set -e set -e
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
git clone https://github.com/NVIDIA/Megatron-LM.git git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM cd Megatron-LM
......
...@@ -46,17 +46,20 @@ def _mapped_argtype(opt, typemap): ...@@ -46,17 +46,20 @@ def _mapped_argtype(opt, typemap):
def _parse_args(argv=None, namespace=None): def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.") parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.")
parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") parser.add_argument("-s", "--seq-length", type=int, default=512, help="Input sequence length.")
parser.add_argument( parser.add_argument(
"-n", "--num-heads", type=int, default=64, help="Number of attention heads." "-n", "--num-heads", type=int, default=12, help="Number of attention heads."
) )
parser.add_argument( parser.add_argument(
"-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head."
) )
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument( parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
) )
parser.add_argument(
"--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
)
parser.add_argument( parser.add_argument(
"--p2p", action="store_true", default=False, help="Test overlap with P2P comms." "--p2p", action="store_true", default=False, help="Test overlap with P2P comms."
) )
...@@ -106,7 +109,7 @@ def _parse_args(argv=None, namespace=None): ...@@ -106,7 +109,7 @@ def _parse_args(argv=None, namespace=None):
help="Set device clock speed to a fixed value via `nvidia-smi`.", help="Set device clock speed to a fixed value via `nvidia-smi`.",
) )
parser.add_argument( parser.add_argument(
"--scale", type=float, default=1e-2, help="Set scaling factor for input and weight tensors." "--std", type=float, default=0.023, help="Standard deviation for input and weight tensors."
) )
parser.add_argument( parser.add_argument(
"--tcp-init", "--tcp-init",
...@@ -135,6 +138,9 @@ def _parse_args(argv=None, namespace=None): ...@@ -135,6 +138,9 @@ def _parse_args(argv=None, namespace=None):
+ "initialization." + "initialization."
), ),
) )
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs."
)
parser.add_argument( parser.add_argument(
"-v", "--verbose", action="store_true", default=False, help="Verbose info messages." "-v", "--verbose", action="store_true", default=False, help="Verbose info messages."
) )
...@@ -150,7 +156,10 @@ def _parse_args(argv=None, namespace=None): ...@@ -150,7 +156,10 @@ def _parse_args(argv=None, namespace=None):
if opts.fp8: if opts.fp8:
warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.") warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
opts.fp8 = False opts.fp8 = False
elif opts.comm_type == 1 and not opts.p2p: elif opts.comm_type == 1:
if opts.atomic:
setattr(opts, "atomic_rs_p2p", opts.p2p)
if not opts.p2p:
warnings.warn("All-gather overlap is only supported with point-2-point comms.") warnings.warn("All-gather overlap is only supported with point-2-point comms.")
opts.p2p = True opts.p2p = True
...@@ -203,13 +212,14 @@ def _main(opts): ...@@ -203,13 +212,14 @@ def _main(opts):
print(f"[rank:{LOCAL_RANK}] {msg}\n", end="", flush=True) print(f"[rank:{LOCAL_RANK}] {msg}\n", end="", flush=True)
# Info printout # Info printout
def dist_print(msg, src=None, info=False, section=False, group=None): def dist_print(msg, src=None, info=False, error=False, section=False, group=None):
group = dist.new_group() if group is None else group group = dist.new_group() if group is None else group
rank = dist.get_rank(group) rank = dist.get_rank(group)
stream = sys.stderr if error else sys.stdout
if info or opts.verbose: if info or opts.verbose:
if section: if section:
if rank == (0 if src is None else src): if rank == (0 if src is None else src):
print("\n", end="", flush=True) stream.write("\n")
dist.barrier(group) dist.barrier(group)
if src is None or rank == src: if src is None or rank == src:
prefix = "[GLOBAL] " if src is not None else f"[rank:{rank}] " prefix = "[GLOBAL] " if src is not None else f"[rank:{rank}] "
...@@ -217,7 +227,7 @@ def _main(opts): ...@@ -217,7 +227,7 @@ def _main(opts):
msg = "\n".join( msg = "\n".join(
[prefix + lines[0]] + [(" " * len(prefix)) + line for line in lines[1:]] [prefix + lines[0]] + [(" " * len(prefix)) + line for line in lines[1:]]
) )
print(msg + "\n", end="", flush=True) stream.write(msg + "\n")
dist.barrier(group) dist.barrier(group)
# Initialize torch.distributed global process group and get TP group # Initialize torch.distributed global process group and get TP group
...@@ -312,7 +322,9 @@ def _main(opts): ...@@ -312,7 +322,9 @@ def _main(opts):
hidden_size = opts.num_heads * opts.head_dim hidden_size = opts.num_heads * opts.head_dim
inp_shape = (opts.seq_length, opts.batch_size, hidden_size) inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
outer_size = reduce(operator.mul, inp_shape[:-1], 1) outer_size = reduce(operator.mul, inp_shape[:-1], 1)
ubuf_dtype = torch.uint8 if opts.fp8 and opts.comm_type == 1 else torch.bfloat16 ubuf_dtype = torch.bfloat16
if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output):
ubuf_dtype = torch.uint8
sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda") sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda")
ub_obj = ub_obj = ( ub_obj = ub_obj = (
tex.UbufP2PCommOverlap( tex.UbufP2PCommOverlap(
...@@ -331,7 +343,7 @@ def _main(opts): ...@@ -331,7 +343,7 @@ def _main(opts):
3, # Max concurrent GEMM streams 3, # Max concurrent GEMM streams
opts.comm_type == 0, # overlap with reduce scatter opts.comm_type == 0, # overlap with reduce scatter
opts.atomic, # use a single GEMM with atomic-counters opts.atomic, # use a single GEMM with atomic-counters
True, # Use copy engine for P2P communications not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
ub_callbacks, ub_callbacks,
) )
if opts.p2p if opts.p2p
...@@ -349,7 +361,7 @@ def _main(opts): ...@@ -349,7 +361,7 @@ def _main(opts):
4, # Number of communication splits 4, # Number of communication splits
True, # Set SM margin True, # Set SM margin
3, # Max concurrent GEMM streams 3, # Max concurrent GEMM streams
opts.atomic, # uUe a single GEMM with atomic-counters opts.atomic, # Use a single GEMM with atomic-counters
ub_callbacks, ub_callbacks,
) )
) )
...@@ -357,8 +369,13 @@ def _main(opts): ...@@ -357,8 +369,13 @@ def _main(opts):
# Numerical check on AG + atomic GEMM requires testing an AG+RS pair # Numerical check on AG + atomic GEMM requires testing an AG+RS pair
ub_obj2 = None ub_obj2 = None
if opts.atomic and opts.comm_type == 1 and opts.check_numerics: if opts.atomic and opts.comm_type == 1 and opts.check_numerics:
sample_buffer2 = torch.empty((outer_size, hidden_size), dtype=torch.bfloat16, device="cuda") sample_buffer2 = torch.empty(
ub_obj2 = tex.UbufP2PCommOverlap( (outer_size, hidden_size),
dtype=torch.uint8 if opts.fp8_output else torch.bfloat16,
device="cuda",
)
ub_obj2 = (
tex.UbufP2PCommOverlap(
sample_buffer2, # Sample userbuffer sample_buffer2, # Sample userbuffer
WORLD_RANK, # World rank WORLD_RANK, # World rank
WORLD_SIZE, # World size WORLD_SIZE, # World size
...@@ -377,6 +394,25 @@ def _main(opts): ...@@ -377,6 +394,25 @@ def _main(opts):
True, # use copy engine for P2P communications True, # use copy engine for P2P communications
ub_callbacks, ub_callbacks,
) )
if opts.atomic_rs_p2p
else tex.UbufCommOverlap(
sample_buffer2, # Sample userbuffer
WORLD_RANK, # World rank
WORLD_SIZE, # World size
LOCAL_RANK, # Rank within the node
LOCAL_SIZE, # Number of ranks/GPUs per node
0, # Node ID
1, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
16, # Number of communication SMs
2, # CGA cluster size
4, # Number of communication splits
True, # Set SM margin
3, # Max concurrent GEMM streams
True, # Use a single GEMM with atomic-counters
ub_callbacks,
)
)
# Figure out problem sizing: # Figure out problem sizing:
# M = sequence * batch # M = sequence * batch
...@@ -409,43 +445,53 @@ def _main(opts): ...@@ -409,43 +445,53 @@ def _main(opts):
# Initialize distributed input tensor and GEMM kernels # Initialize distributed input tensor and GEMM kernels
torch.manual_seed(opts.seed + tp_rank) torch.manual_seed(opts.seed + tp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank) torch.cuda.manual_seed(opts.seed + tp_rank)
inp = torch.mul(torch.rand(local_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale) inp = torch.nn.init.normal_(
kernel_t = torch.mul( torch.empty(local_inp_shape, dtype=torch.bfloat16, device="cuda"),
torch.rand(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale mean=0.0,
std=opts.std,
)
kernel_t = torch.nn.init.normal_(
torch.empty(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
std=opts.std,
) )
if ub_obj2 is not None: if ub_obj2 is not None:
kernel2_t = torch.mul( kernel2_t = torch.nn.init.normal_(
torch.rand(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), opts.scale torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
std=opts.std,
) )
# Gather global tensors and calculate reference result (need these first for Fp8 scales) # Gather global tensors and calculate reference result (need these first for Fp8 scales)
if opts.bulk_overlap: if opts.bulk_overlap:
ker_g = torch.transpose(kernel_t, 0, 1) ker_g = torch.transpose(kernel_t, 0, 1)
inp_g = inp inp_g = inp
bulk_inp = torch.mul( bulk_inp = torch.nn.init.normal_(
torch.rand(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"), opts.scale torch.empty(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
std=opts.std,
) )
else: else:
if opts.comm_type == 1: if opts.comm_type == 1:
# AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K) # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K)
ker_g = torch.transpose( ker_g = torch.transpose(
te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1 te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1
) ).to(dtype=torch.float32)
# AG Input: (M/P, N) -> gather -> (M, N) # AG Input: (M/P, N) -> gather -> (M, N)
inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0] inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0].to(dtype=torch.float32)
if ub_obj2 is not None: if ub_obj2 is not None:
ker2_g = te.distributed.gather_along_first_dim( ker2_g = te.distributed.gather_along_first_dim(
torch.transpose(kernel2_t, 0, 1), tp_group torch.transpose(kernel2_t, 0, 1), tp_group
)[0] )[0].to(dtype=torch.float32)
else: else:
# RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N) # RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N)
ker_g = te.distributed.gather_along_first_dim( ker_g = te.distributed.gather_along_first_dim(
torch.transpose(kernel_t, 0, 1), tp_group torch.transpose(kernel_t, 0, 1), tp_group
)[0] )[0].to(dtype=torch.float32)
# RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) # RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K)
inp_g = torch.transpose( inp_g = torch.transpose(
te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1 te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1
) ).to(dtype=torch.float32)
if opts.bulk_overlap: if opts.bulk_overlap:
if opts.comm_type == 1: if opts.comm_type == 1:
...@@ -459,7 +505,7 @@ def _main(opts): ...@@ -459,7 +505,7 @@ def _main(opts):
else: else:
ref_g = torch.matmul(inp_g, ker_g) ref_g = torch.matmul(inp_g, ker_g)
if ub_obj2 is not None: if ub_obj2 is not None:
inp2_g = torch.mul(ref_g, opts.scale) inp2_g = torch.nn.functional.gelu(ref_g)
ref2_g = torch.matmul(inp2_g, ker2_g) ref2_g = torch.matmul(inp2_g, ker2_g)
if opts.fp8: if opts.fp8:
...@@ -483,7 +529,10 @@ def _main(opts): ...@@ -483,7 +529,10 @@ def _main(opts):
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax)
ref_amax = torch.max(torch.abs(ref_g)) ref_amax = torch.max(torch.abs(ref_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax)
if ub_obj2 is not None: if opts.bulk_overlap and opts.comm_type == 0:
bulk_amax = torch.max(torch.abs(bulk_inp))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax)
elif ub_obj2 is not None:
inp2_amax = torch.max(torch.abs(inp2_g)) inp2_amax = torch.max(torch.abs(inp2_g))
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax)
ker2_amax = torch.max(torch.abs(ker2_g)) ker2_amax = torch.max(torch.abs(ker2_g))
...@@ -502,7 +551,11 @@ def _main(opts): ...@@ -502,7 +551,11 @@ def _main(opts):
kernel_t_fp8 = tex.cast_to_fp8( kernel_t_fp8 = tex.cast_to_fp8(
kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype
) )
if ub_obj2 is not None: if opts.bulk_overlap and opts.comm_type == 0:
bulk_inp_fp8 = tex.cast_to_fp8(
bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype
)
elif ub_obj2 is not None:
kernel2_t_fp8 = tex.cast_to_fp8( kernel2_t_fp8 = tex.cast_to_fp8(
kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype
) )
...@@ -521,7 +574,14 @@ def _main(opts): ...@@ -521,7 +574,14 @@ def _main(opts):
rtol=0.125, rtol=0.125,
atol=0.0675, atol=0.0675,
) )
if ub_obj2 is not None: if opts.bulk_overlap and opts.comm_type == 0:
torch.allclose(
bulk_inp.to(dtype=torch.float32),
bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT],
rtol=0.125,
atol=0.0675,
)
elif ub_obj2 is not None:
torch.allclose( torch.allclose(
kernel2_t.to(dtype=torch.float32), kernel2_t.to(dtype=torch.float32),
kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT],
...@@ -534,6 +594,8 @@ def _main(opts): ...@@ -534,6 +594,8 @@ def _main(opts):
ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT])
if ub_obj2 is not None: if ub_obj2 is not None:
ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT])
elif opts.bulk_overlap:
ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT])
else: else:
ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT]) ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT])
...@@ -556,7 +618,7 @@ def _main(opts): ...@@ -556,7 +618,7 @@ def _main(opts):
) )
else: else:
if opts.bulk_overlap: if opts.bulk_overlap:
ub_obj.copy_input_to_ubuf(bulk_inp, 0) ub_obj.copy_input_to_ubuf(bulk_inp_fp8 if opts.fp8 else bulk_inp, 0)
ubuf_out = None ubuf_out = None
else: else:
ubuf_out = ub_obj.get_ubuf_output(1) ubuf_out = ub_obj.get_ubuf_output(1)
...@@ -565,16 +627,9 @@ def _main(opts): ...@@ -565,16 +627,9 @@ def _main(opts):
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
) )
# Trigger GEMM # Wrap GEMM ops in condensed functions to make CUDA Graphs easier to use
total_iters = opts.warmup_iters + opts.timing_iters def _fp8_gemm():
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)] return tex.fp8_gemm(
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)]
torch.cuda.synchronize()
if opts.fp8:
for i in range(total_iters):
start_events[i].record()
all_outputs = tex.fp8_gemm(
kernel_t_fp8, kernel_t_fp8,
fp8_meta.scale_inv, fp8_meta.scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
...@@ -583,7 +638,7 @@ def _main(opts): ...@@ -583,7 +638,7 @@ def _main(opts):
fp8_meta.scale_inv, fp8_meta.scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype, fp8_dtype,
torch.bfloat16, torch.uint8 if opts.fp8_output else torch.bfloat16,
te.module.base.get_workspace(), te.module.base.get_workspace(),
bias=None, bias=None,
use_bias=False, use_bias=False,
...@@ -593,16 +648,29 @@ def _main(opts): ...@@ -593,16 +648,29 @@ def _main(opts):
ub=ub_obj, ub=ub_obj,
extra_output_tensor=rs_out, extra_output_tensor=rs_out,
out=ubuf_out, out=ubuf_out,
D_dtype=fp8_dtype if opts.fp8_output else None,
fp8_meta_tensor=fp8_meta if opts.fp8_output else None,
out_index=tex.FP8FwdTensors.GEMM1_OUTPUT if opts.fp8_output else None,
) )
end_events[i].record()
if ub_obj2 is not None: def _fp8_gemm2(gemm1_out):
gemm2_inp = tex.cast_to_fp8( gemm2_inp = tex.gelu(
torch.mul(all_outputs[0], opts.scale), (
tex.cast_from_fp8(
gemm1_out,
fp8_meta,
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_dtype,
tex.DType.kFloat32,
)
if opts.fp8_output
else gemm1_out
),
fp8_meta, fp8_meta,
tex.FP8FwdTensors.GEMM2_INPUT, tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype, fp8_dtype,
) )
all_outputs = tex.fp8_gemm( return tex.fp8_gemm(
kernel2_t_fp8, kernel2_t_fp8,
fp8_meta.scale_inv, fp8_meta.scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
...@@ -611,21 +679,27 @@ def _main(opts): ...@@ -611,21 +679,27 @@ def _main(opts):
fp8_meta.scale_inv, fp8_meta.scale_inv,
tex.FP8FwdTensors.GEMM2_INPUT, tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype, fp8_dtype,
torch.bfloat16, torch.uint8 if opts.fp8_output else torch.bfloat16,
te.module.base.get_workspace(), te.module.base.get_workspace(),
bias=None, bias=None,
use_bias=False, use_bias=False,
gelu=False, gelu=False,
use_split_accumulator=te.module.base._2X_ACC_FPROP, use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, ub_algo=(
tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
if opts.atomic_rs_p2p
else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
),
ub=ub_obj2, ub=ub_obj2,
extra_output_tensor=rs_out2, extra_output_tensor=rs_out2,
out=ubuf_out2, out=ubuf_out2,
D_dtype=fp8_dtype if opts.fp8_output else None,
fp8_meta_tensor=fp8_meta if opts.fp8_output else None,
out_index=tex.FP8FwdTensors.GEMM2_OUTPUT if opts.fp8_output else None,
) )
else:
for i in range(total_iters): def _gemm():
start_events[i].record() return tex.gemm(
all_outputs = tex.gemm(
kernel_t, kernel_t,
gemm_inp, gemm_inp,
torch.bfloat16, torch.bfloat16,
...@@ -638,6 +712,45 @@ def _main(opts): ...@@ -638,6 +712,45 @@ def _main(opts):
extra_output_tensor=rs_out, extra_output_tensor=rs_out,
out=ubuf_out, out=ubuf_out,
) )
# Trigger GEMM
total_iters = opts.warmup_iters + opts.timing_iters
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)]
torch.cuda.synchronize()
if opts.use_cuda_graphs:
# Trace the CUDA graph first
g = torch.cuda.CUDAGraph()
if opts.fp8:
if ub_obj is None:
with torch.cuda.graph(g):
all_outputs = _fp8_gemm()
else:
with torch.cuda.graph(g):
all_outputs = _fp8_gemm()
_ = _fp8_gemm2(all_outputs[0])
else:
with torch.cuda.graph(g):
all_outputs = _gemm()
# Now replay the CUDA graph in a loop
for i in range(total_iters):
start_events[i].record()
g.replay()
end_events[i].record()
else:
for i in range(total_iters):
if opts.fp8:
start_events[i].record()
all_outputs = _fp8_gemm()
end_events[i].record()
if ub_obj2 is not None:
_fp8_gemm2(all_outputs[0])
else:
start_events[i].record()
all_outputs = _gemm()
end_events[i].record() end_events[i].record()
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -679,7 +792,11 @@ def _main(opts): ...@@ -679,7 +792,11 @@ def _main(opts):
ref_out = ref_g ref_out = ref_g
output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}" output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}"
dist_print(output_info, src=0 if opts.comm_type == 0 else None, section=True) dist_print(
output_info,
src=0 if opts.comm_type == 0 else None,
section=True,
)
test_nonzeros = torch.count_nonzero(test_out) test_nonzeros = torch.count_nonzero(test_out)
ref_nonzeros = torch.count_nonzero(ref_out) ref_nonzeros = torch.count_nonzero(ref_out)
...@@ -691,11 +808,21 @@ def _main(opts): ...@@ -691,11 +808,21 @@ def _main(opts):
if opts.comm_type == 1: if opts.comm_type == 1:
if ub_obj2 is not None: if ub_obj2 is not None:
# AG+RS Output: (M/P, N) -> gather -> (M, N) # AG+RS Output: (M/P, N) -> gather -> (M, N)
output = rs_out2 output = rs_out2.to(dtype=torch.float32)
test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] test_out = te.distributed.gather_along_first_dim(output, tp_group)[0]
else: else:
# AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K)
output = all_outputs[0] output = (
tex.cast_from_fp8(
all_outputs[0],
fp8_meta,
tex.FP8FwdTensors.GEMM1_OUTPUT,
fp8_dtype,
tex.DType.kFloat32,
)
if opts.fp8_output
else all_outputs[0]
)
test_out = torch.transpose( test_out = torch.transpose(
te.distributed.gather_along_first_dim( te.distributed.gather_along_first_dim(
torch.transpose(output, 0, 1), tp_group torch.transpose(output, 0, 1), tp_group
...@@ -705,7 +832,7 @@ def _main(opts): ...@@ -705,7 +832,7 @@ def _main(opts):
) )
else: else:
# RS Output: (M/P, N) -> gather -> (M, N) # RS Output: (M/P, N) -> gather -> (M, N)
output = rs_out output = rs_out.to(dtype=torch.float32)
test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] test_out = te.distributed.gather_along_first_dim(output, tp_group)[0]
if opts.fp8: if opts.fp8:
...@@ -755,30 +882,33 @@ def _main(opts): ...@@ -755,30 +882,33 @@ def _main(opts):
torch.cuda.synchronize() torch.cuda.synchronize()
dist.barrier(tp_group) dist.barrier(tp_group)
test_out = test_out.to(dtype=torch.float32)
ref_out = ref_out.to(dtype=torch.float32)
error_below_tol = torch.allclose(
test_out,
ref_out,
rtol=0.125 if opts.fp8 else 0.02,
atol=0.0675 if opts.fp8 else 0.001,
)
diff = torch.abs(test_out - ref_out).flatten() diff = torch.abs(test_out - ref_out).flatten()
m = torch.argmax(diff) m = torch.argmax(diff)
abs_err = diff[m].item() abs_err = diff[m].item()
rel_err = abs_err / (ref_out.flatten()[m].item() + 1e-5) rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
if not error_below_tol: rtol = 0.125 if opts.fp8 else 0.02
atol = 0.0625 if opts.fp8 else 0.001
if rel_err > rtol and abs_err > atol:
numerics_failed = True numerics_failed = True
numerics_info = ( numerics_info = (
"NUMERICAL CHECK FAILED: " "NUMERICAL CHECK FAILED: "
+ f"Outputs not close enough at index {m.item()} " + f"Outputs not close enough at index {m.item()} "
+ f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} " + f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} | "
+ f"(abs error = {abs_err} | rel error = {rel_err})." + f"rel. error = {rel_err} (tol = {rtol}) | "
+ f"abs. error = {abs_err} (tol = {atol})"
) )
else: else:
numerics_info = f"NUMERICAL CHECK PASSED: abs error = {abs_err} | rel error = {rel_err}" numerics_info = "NUMERICAL CHECK PASSED: "
if rel_err <= rtol:
numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
" | " if abs_err < atol else ""
)
if abs_err <= atol:
numerics_info += f"abs. error = {abs_err} (tol = {atol})"
dist_print(numerics_info, src=0, section=True, info=True, group=tp_group) dist_print(
numerics_info, src=0, section=True, info=True, error=numerics_failed, group=tp_group
)
dist.barrier(tp_group) dist.barrier(tp_group)
if LOCAL_RANK == 0: if LOCAL_RANK == 0:
......
#!/usr/bin/python3
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import socket
import argparse
import warnings
from functools import partial
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
def _te_layer_argtype(name):
    """Convert a case-insensitive layer name into the matching TE layer class.

    Intended for use as an ``argparse`` ``type=`` callback.

    Args:
        name: Layer name such as "linear" or "transformerlayer" (any case).

    Returns:
        The corresponding Transformer Engine layer class.

    Raises:
        argparse.ArgumentTypeError: If `name` does not match a supported layer.
    """
    te_layers = [
        te.Linear,
        te.LayerNormLinear,
        te.LayerNormMLP,
        te.MultiheadAttention,
        te.TransformerLayer,
    ]
    layer_map = {layer.__name__.lower(): layer for layer in te_layers}
    try:
        return layer_map[name.lower()]
    except KeyError as err:
        # list(...) so the error shows clean names instead of a dict_keys(...) repr.
        raise argparse.ArgumentTypeError(
            f"Invalid TE layer name! Please choose from: {list(layer_map.keys())}"
        ) from err
def _get_layer_args(config, tp_group, tp_size, reference=False):
    """Build positional args, keyword args, and input shape for the configured TE layer.

    When `reference` is True, every Userbuffers comm+GEMM overlap option is
    disabled so the resulting layer can serve as a non-overlapped reference
    model for numerical comparison.
    """
    overlap = not reference
    hidden = config.num_heads * config.head_dim
    shape = [config.seq_length, config.batch_size, hidden]
    layer_args = [hidden]
    layer_kwargs = {
        "params_dtype": torch.float32,
        "device": "cuda",
        "tp_group": tp_group,
        "tp_size": tp_size,
        "sequence_parallel": True,
        "ub_overlap_ag": overlap,
    }
    if config.layer_type is te.Linear:
        # Row-parallel projection: input is sharded along the hidden dimension.
        shape[2] = hidden // tp_size
        layer_args.append(hidden)
        layer_kwargs.update(parallel_mode="row", ub_overlap_rs=overlap, ub_name="proj")
    else:
        # Sequence-parallel layers: input is sharded along the sequence dimension.
        shape[0] = config.seq_length // tp_size
        layer_kwargs.update(ub_bulk_wgrad=overlap, ub_bulk_dgrad=overlap)
        if config.layer_type is te.LayerNormLinear:
            layer_args.append(3 * hidden)
            layer_kwargs.update(parallel_mode="column", ub_name="qkv")
        else:
            layer_kwargs.update(set_parallel_mode=True, ub_overlap_rs=overlap)
            if config.layer_type in (te.LayerNormMLP, te.TransformerLayer):
                layer_args.append(4 * hidden)
                layer_kwargs["seq_length"] = config.seq_length
            if config.layer_type in (te.MultiheadAttention, te.TransformerLayer):
                layer_args.append(config.num_heads)
                layer_kwargs.update(attention_dropout=0.0, fuse_qkv_params=True)
                if config.layer_type is te.MultiheadAttention:
                    layer_kwargs["input_layernorm"] = True
                else:
                    layer_kwargs.update(ub_tp_comm_overlap=overlap, hidden_dropout=0.0)
    return layer_args, layer_kwargs, shape
def _parse_args(argv=None, namespace=None):
    """Parse command-line options for the comm+GEMM overlap layer test."""
    parser = argparse.ArgumentParser(
        description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers."
    )

    def _flag(name, help_text):
        # All boolean switches share the same store_true/default-off shape.
        parser.add_argument(name, action="store_true", default=False, help=help_text)

    parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP)
    parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
    parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
    parser.add_argument(
        "-n", "--num-heads", type=int, default=12, help="Number of attention heads."
    )
    parser.add_argument(
        "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head."
    )
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    _flag("--fp8", "Enables the te.fp8_autocast() context.")
    _flag("--fp8-init", "Initialize primary weights in FP8.")
    _flag("--tcp-init", "Initialize torch.distributed with TcpStore.")
    _flag(
        "--bind-to-device",
        "Initialize torch.distributed with `device_id` to bind each rank to a single device.",
    )
    parser.add_argument(
        "--bootstrap-backend",
        type=str.lower,
        default="nccl",
        choices=["gloo", "mpi", "nccl"],
        help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
    )
    _flag("--use-cuda-graphs", "Use CUDA Graphs.")
    _flag("--debug", "Print out additional debug information.")

    opts = parser.parse_args(argv, namespace)
    # Attention-bearing layers cannot be captured into a CUDA graph here,
    # so silently fall back to eager execution with a warning.
    if opts.use_cuda_graphs and opts.layer_type in (te.MultiheadAttention, te.TransformerLayer):
        warnings.warn(f"{opts.layer_type.__name__} does not support CUDA Graphs!")
        opts.use_cuda_graphs = False
    return opts
def _compare_tensors(name, test, ref, rtol, atol):
# Make sure tensors aren't zero and we don't pass trivially
if test.count_nonzero() == 0:
if ref.count_nonzero() == 0:
warnings.warn(
f"WARNING: {name} is a zero-tensor for both test and reference models!",
category=RuntimeWarning,
)
else:
numerics_info = (
f"NUMERICAL CHECK FAILED: {name} is a zero-tensor but does not match reference!"
)
return 1, numerics_info
diff = torch.abs(test - ref).flatten()
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5)
numerics_failed = 0
if rel_err > rtol and abs_err > atol:
numerics_failed = 1
numerics_info = (
"NUMERICAL CHECK FAILED: "
+ f"{name} not close enough at index {m.item()} "
+ f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | "
+ f"rel. error = {rel_err} (tol = {rtol}) | "
+ f"abs. error = {abs_err} (tol = {atol})"
)
else:
numerics_info = f"NUMERICAL CHECK PASSED: {name} | "
if rel_err <= rtol:
numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
" | " if abs_err <= atol else "."
)
if abs_err <= atol:
numerics_info += f" abs. error = {abs_err} (tol = {atol})"
return numerics_failed, numerics_info
def _train(opts) -> int:
    """Run fwd/bwd on a TE layer with comm+GEMM overlap against a non-overlapped reference.

    Initializes torch.distributed and Userbuffers, builds a test model (with
    overlap) and a reference model (without), copies parameters from test to
    reference, runs one forward/backward pass on identical inputs with identical
    RNG state, and compares the output and all non-layernorm gradients.

    Args:
        opts: Parsed command-line options (see ``_parse_args()``).

    Returns:
        0 if all numerical checks passed, 1 otherwise (reduced across ranks).

    Raises:
        RuntimeError: If not launched under `mpirun` or `torchrun`.
    """
    # Derive rank/world info from the launcher's environment variables.
    if "OMPI_COMM_WORLD_SIZE" in os.environ:
        # Execution with `mpirun -np N`
        WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
        WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
        LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
        LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
        # MPI launches force TCP-store init, device binding and MPI bootstrap.
        opts.tcp_init = True
        opts.bind_to_device = True
        opts.bootstrap_backend = "mpi"
    elif "TORCHELASTIC_RUN_ID" in os.environ:
        WORLD_RANK = int(os.getenv("RANK", "0"))
        WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
        LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    else:
        raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
    # Single-node test: every rank must live on this node.
    assert LOCAL_SIZE == WORLD_SIZE

    def dist_print(msg, src=None, end="\n", debug=False, error=False):
        # Print from a single rank (rank 0 unless `src` is given), with a
        # barrier so output does not interleave across ranks.
        if debug and not opts.debug:
            return
        stream = sys.stderr if error else sys.stdout
        if WORLD_RANK == (0 if src is None else src):
            stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n")
        dist.barrier()

    # Set device and initialize RNG states
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed(opts.seed)

    # Initialize torch.distributed global process group and get DP/TP groups
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    if opts.tcp_init:
        MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
        MASTER_PORT = os.getenv("MASTER_PORT", "1234")
        dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
    if opts.bind_to_device or opts.bootstrap_backend == "nccl":
        # Binding each rank to a single device is required for NCCL bootstrap.
        dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

    # Initialize userbuffers
    te.module.base.initialize_ub(
        [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
        WORLD_SIZE,
        use_fp8=opts.fp8,
        dtype=torch.bfloat16,
        bootstrap_backend=opts.bootstrap_backend,
    )

    # Initialize the Transformer Engine layer with overlap
    args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE)
    with te.fp8_model_init(enabled=opts.fp8_init):
        test_model = opts.layer_type(*args, **kwargs)
    dist_print("Initialized test model...", debug=True)

    # Initialize the reference model and copy all parameters
    ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True)
    with te.fp8_model_init(enabled=opts.fp8_init):
        ref_model = opts.layer_type(*ref_args, **ref_kwargs)
    dist_print("Initialized reference model...", debug=True)
    for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
        with torch.no_grad():
            ref_param.copy_(test_param)
        torch.testing.assert_close(test_param, ref_param, rtol=0.0, atol=0.0)
    dist_print("Copied parameters from test model to reference model...", debug=True)

    # Fp8 recipe setup
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    # Prepare random input tensors
    test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
    test_x.retain_grad()
    ref_x = torch.empty_like(test_x).requires_grad_(True)
    with torch.no_grad():
        ref_x.copy_(test_x)
    torch.testing.assert_close(test_x, ref_x, rtol=0.0, atol=0.0)
    ref_x.retain_grad()

    # Execute fwd/bwd and collect tensors to test
    def run_fwd_bwd(model, x):
        # One fwd/bwd step under bf16 autocast (and fp8 autocast when enabled).
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
                y = model(x)
                if isinstance(y, tuple):
                    out, *_ = y
                else:
                    out = y
        loss = out.sum()
        loss.backward()
        return out

    # Save RNG state so the reference pass can replay the same randomness.
    torch_rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}"))
    if opts.use_cuda_graphs:
        # Capture the test model's fwd/bwd into a CUDA graph, then replay it.
        test_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(test_graph):
            test_out = run_fwd_bwd(test_model, test_x)
        test_graph.replay()
        del test_graph
    else:
        test_out = run_fwd_bwd(test_model, test_x)
    # Collect output, input grad and parameter grads (skip layernorm params).
    test_grads = [test_out, test_x.grad]
    names = ["output", "input.grad"]
    for test_name, test_param in test_model.named_parameters():
        if test_param.requires_grad and "layer_norm" not in test_name:
            test_grads.append(test_param.grad)
            names.append(test_name + ".grad")

    # Restore RNG so the reference model sees identical random state.
    torch.set_rng_state(torch_rng_state)
    torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}"))
    if opts.use_cuda_graphs:
        ref_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(ref_graph):
            ref_out = run_fwd_bwd(ref_model, ref_x)
        ref_graph.replay()
        del ref_graph
    else:
        ref_out = run_fwd_bwd(ref_model, ref_x)
    ref_grads = [ref_out, ref_x.grad]
    for ref_name, ref_param in ref_model.named_parameters():
        if ref_param.requires_grad and "layer_norm" not in ref_name:
            ref_grads.append(ref_param.grad)

    # Make sure we have the same number of gradients
    numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
    if len(test_grads) != len(ref_grads):
        numerics_failed[0] = 1
        numerics_info = (
            "NUMERICAL CHECK FAILED: Incorrect number of gradients, "
            + f"expected {len(ref_grads)} but got {len(test_grads)}."
        )
        dist_print(numerics_info, src=WORLD_RANK, error=True)
    # MAX-reduce so a failure on any rank fails every rank.
    dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)

    # Now validate accuracy
    if not bool(numerics_failed.item()):
        for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
            # Looser tolerances under FP8 execution.
            rtol = 0.125 if opts.fp8 else 0.025
            atol = 0.0625 if opts.fp8 else 0.00125
            grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
            dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
            numerics_failed[0] = int(grad_failed)
            dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
            if bool(numerics_failed.item()):
                break

    # Tear down Userbuffers and the process group before exiting.
    te.module.base.destroy_ub()
    dist_print("Destroying Userbuffers objects...", debug=True)
    dist_print("Destroying all process groups...", debug=True)
    dist.destroy_process_group()
    if opts.debug and WORLD_RANK == 0:
        print("Exiting...\n", end="", flush=True)

    return numerics_failed[0].item()
if __name__ == "__main__":
    # Process exit code is nonzero when any rank fails its numerical checks.
    sys.exit(_train(_parse_args()))
...@@ -7,16 +7,27 @@ from pathlib import Path ...@@ -7,16 +7,27 @@ from pathlib import Path
import pytest import pytest
import torch import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
RNG_SEED: int = 1234 RNG_SEED: int = 1234
SEQ_LENGTH: int = 2024 SEQ_LENGTH: int = 512
BATCH_SIZE: int = 2 BATCH_SIZE: int = 2
NUM_HEADS: int = 64 NUM_HEADS: int = 12
HEAD_DIM: int = 128 HEAD_DIM: int = 64
TE_LAYERS = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
TEST_ROOT = Path(__file__).parent.resolve() TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(torch.cuda.device_count(), 4) NUM_PROCS: int = min(torch.cuda.device_count(), 4)
...@@ -32,49 +43,10 @@ if not tex.device_supports_multicast(): ...@@ -32,49 +43,10 @@ if not tex.device_supports_multicast():
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@pytest.mark.skipif(NUM_PROCS < 2, reason="Comm+GEMM overlap requires at least 2 GPUs.") def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate):
@pytest.mark.parametrize(
"fp8,p2p,comm_type,aggregate,atomic,bulk",
[
# FP8, P2P, Type, Aggregate, Atomic, Bulk
(False, True, "AG", False, False, False),
(False, True, "AG", True, False, False),
(True, True, "AG", False, False, False),
(True, True, "AG", True, False, False),
(False, False, "RS", False, False, False),
(False, True, "RS", False, False, False),
(True, False, "RS", False, False, False),
(True, True, "RS", False, False, False),
(True, False, "RS", False, True, False),
(True, True, "RS", False, True, False),
(False, False, "AG", False, False, True),
(False, False, "RS", False, False, True),
],
ids=[
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | BF16 | RING-EXCHANGE (2X AGGREGATED) ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE ",
" AG -> SPLIT GEMM | FP8 | RING-EXCHANGE (2X AGGREGATED) ",
" SPLIT GEMM -> RS | BF16 | PIPELINE ",
" SPLIT GEMM -> RS | BF16 | RING-EXCHANGE ",
" SPLIT GEMM -> RS | FP8 | PIPELINE ",
" SPLIT GEMM -> RS | FP8 | RING-EXCHANGE ",
" ATOMIC GEMM -> RS | FP8 | PIPELINE ",
" ATOMIC GEMM -> RS | FP8 | RING-EXCHANGE ",
" BULK AG & GEMM | BF16 | PIPELINE ",
" BULK RS & GEMM | BF16 | PIPELINE ",
],
)
def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
"""
Test comm+GEMM overlap algorithms with direct calls to
te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm
"""
test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = ( test_cmd = LAUNCH_CMD + [
LAUNCH_CMD str(test_path),
+ [str(test_path)]
+ [
"--check-numerics", "--check-numerics",
f"--seed={RNG_SEED}", f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}", f"--seq-length={SEQ_LENGTH}",
...@@ -83,15 +55,16 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk): ...@@ -83,15 +55,16 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
f"--head-dim={HEAD_DIM}", f"--head-dim={HEAD_DIM}",
f"--comm-type={comm_type}", f"--comm-type={comm_type}",
] ]
)
if bulk: if bulk:
test_cmd.append("--bulk-overlap") test_cmd.append("--bulk-overlap")
else: else:
if fp8: if fp8_in:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8") test_cmd.append("--fp8")
if fp8_out:
test_cmd.append("--fp8-output")
if p2p: if p2p:
test_cmd.append("--p2p") test_cmd.append("--p2p")
if aggregate: if aggregate:
...@@ -101,5 +74,173 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk): ...@@ -101,5 +74,173 @@ def test_gemm_with_overlap(fp8, p2p, comm_type, aggregate, atomic, bulk):
pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.") pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.")
test_cmd.append("--atomic") test_cmd.append("--atomic")
output = subprocess.run(test_cmd, env=os.environ, text=True, capture_output=True, check=False) result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
assert "NUMERICAL CHECK PASSED" in str(output) if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(result.stderr.decode())
def _run_layer_with_overlap(layer_type, fp8, fp8_init):
test_path = TEST_ROOT / "run_layer_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--layer-type={layer_type}",
]
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if fp8_init:
test_cmd.append("--fp8-init")
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
os.unsetenv("PYTORCH_JIT")
os.unsetenv("NVTE_TORCH_COMPILE")
os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(result.stderr.decode())
@pytest.mark.parametrize(
"fp8,aggregate",
[
(False, False),
(False, True),
(True, False),
(True, True),
],
ids=[
" BF16 IN - RING-EXCHANGE ",
" BF16 IN - RING-EXCHANGE - 2x AGGREGATED ",
" FP8 IN - RING-EXCHANGE ",
" FP8 IN - RING-EXCHANGE - 2x AGGREGATED ",
],
)
def test_split_all_gather_overlaps(fp8, aggregate):
"""
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate)
@pytest.mark.parametrize(
"fp8_in,fp8_out,p2p",
[
(False, False, False),
(False, False, True),
(True, False, False),
(True, False, True),
(True, True, False),
(True, True, True),
],
ids=[
" BF16 IN - BF16 OUT - PIPELINE ",
" BF16 IN - BF16 OUT - RING-EXCHANGE ",
" FP8 IN - BF16 OUT - PIPELINE ",
" FP8 IN - BF16 OUT - RING-EXCHANGE ",
" FP8 IN - FP8 OUT - PIPELINE ",
" FP8 IN - FP8 OUT - RING-EXCHANGE ",
],
)
def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p):
"""
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False)
@pytest.mark.parametrize(
"ag_type,rs_type,p2p,fp8_out",
[
(0, 0, False, False),
(0, 1, False, False),
(0, 1, False, True),
(0, 2, False, False),
(0, 2, False, True),
(0, 0, True, False),
(0, 0, True, True),
(1, 0, True, False),
(1, 0, True, True),
],
ids=[
" NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ",
" NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ",
" NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ",
" NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ",
" NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ",
" NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ",
" NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ",
" MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ",
" MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ",
],
)
def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):
"""
Test paired (all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter) overlaps with
direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type)
os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type)
_run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False)
@pytest.mark.parametrize(
"comm_type,fp8",
[
("AG", False),
("RS", False),
("RS", True),
],
ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "],
)
def test_bulk_overlaps(comm_type, fp8):
"""
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
@pytest.mark.parametrize(
"layer_type",
[layer.__name__ for layer in TE_LAYERS],
ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS],
)
@pytest.mark.parametrize(
"fp8,fp8_init",
[
(False, False),
(True, False),
(True, True),
],
ids=[
" BF16 GEMM - BF16 PARAMS ",
" FP8 GEMM - BF16 PARAMS ",
" FP8 GEMM - FP8 PARAMS ",
],
)
def test_layers_with_overlap(layer_type, fp8, fp8_init):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, fp8, fp8_init)
...@@ -166,7 +166,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -166,7 +166,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Initialize userbuf communicator // Initialize userbuf communicator
if (!comm_created) { if (!comm_created) {
if (myrank == 0) { if (myrank == 0) {
printf("!!! [UB] Create UbufCommOverlap Communicator\n"); printf("!!! [UB] Create Userbuffers Communicator\n");
} }
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
...@@ -184,16 +184,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -184,16 +184,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Allocate and register extra userbuffers // Allocate and register extra userbuffers
int ubuf_bytes = sample.numel() * sample.element_size(); int ubuf_bytes = sample.numel() * sample.element_size();
if (transformer_engine::getenv<bool>("UB_SKIPMC")) {
_ubuf = torch::zeros_like(sample);
_ubuf_ptr = _ubuf.data_ptr();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, false);
} else {
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes, _ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true); _ub_comm, true);
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
}
if (_ub_comm->myrank == 0) { if (_ub_comm->myrank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg); printf("!!! [UB] Register UBuf %d\n", _ub_reg);
...@@ -264,6 +257,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -264,6 +257,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type,
at::Tensor rs_output) { at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -319,6 +313,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -319,6 +313,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1); int output_c_dim1 = _ubuf.size(1);
output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
_ub_comm->sms = ori_sms;
return {D, output_tensor}; return {D, output_tensor};
} // bulk_overlap } // bulk_overlap
...@@ -336,6 +331,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -336,6 +331,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, bool gemm_overlap, bool accumulate, bool use_split_accumulator, bool gemm_overlap,
at::Tensor rs_output) { at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -352,7 +348,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -352,7 +348,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr()); char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr()); int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
...@@ -388,7 +383,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -388,7 +383,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>( reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m,
_num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);); _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm););
...@@ -402,7 +397,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -402,7 +397,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>( reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);); counter_ptr, _ub_comm, (cudaStream_t)_stream_comm););
...@@ -413,10 +408,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -413,10 +408,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
} }
break; break;
} else { } else {
assert(_ubuf.element_size() != 1);
consumer(counter_ptr, i, (cudaStream_t)_stream_comm); consumer(counter_ptr, i, (cudaStream_t)_stream_comm);
// if (i == _num_splits-1) {
// _ub_comm->sms = UB_MAX_SM;
// }
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm); _ub_comm, (cudaStream_t)_stream_comm);
} }
...@@ -447,6 +440,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -447,6 +440,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
size_t workspaceSize, bool accumulate, bool use_split_accumulator, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output) { bool gemm_overlap, at::Tensor rs_output) {
// Get GEMM dimensions // Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -464,7 +458,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -464,7 +458,6 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr()); char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
...@@ -517,7 +510,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -517,7 +510,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm);); m, _ub_comm, (cudaStream_t)_stream_comm););
...@@ -541,7 +534,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -541,7 +534,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm););
...@@ -577,7 +570,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -577,7 +570,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);); _ub_comm, (cudaStream_t)_stream_comm););
...@@ -682,7 +675,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -682,7 +675,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Initialize userbuf communicator // Initialize userbuf communicator
if (!comm_created) { if (!comm_created) {
if (myrank == 0) { if (myrank == 0) {
printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n"); printf("!!! [UB] Create Userbuffers Communicator\n");
} }
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
...@@ -708,19 +701,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -708,19 +701,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1)); ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1));
num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1); num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1);
} }
if (transformer_engine::getenv<bool>("UB_SKIPMC")) {
_ubuf = torch::zeros({sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)},
sample.options());
_ubuf_ptr = _ubuf.data_ptr();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, false);
} else {
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes, _ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true); _ub_comm, true);
_ubuf = _ubuf = torch::from_blob(
torch::from_blob(_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options());
sample.options());
}
if (_ub_comm->myrank == 0) { if (_ub_comm->myrank == 0) {
printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
} }
...@@ -728,9 +713,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -728,9 +713,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Create tensor chunks for easy management // Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr()); char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
for (int i = 0; i < num_ubuf_chunks; i++) { for (int i = 0; i < num_ubuf_chunks; i++) {
torch::Tensor ubuf_chunk = torch::from_blob( auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)},
ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options()); sample.options());
_ubufs.push_back(ubuf_chunk); _ubufs.push_back(std::move(ubuf_chunk));
ubuf_byte_ptr += ubuf_chunk_bytes; ubuf_byte_ptr += ubuf_chunk_bytes;
} }
...@@ -769,6 +754,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -769,6 +754,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (_rank == 0 && env_p != nullptr) { if (_rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') { if (env_p[0] == '1') {
_use_ce = 0;
_ub_comm->push = 1;
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
} }
} }
...@@ -818,6 +805,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -818,6 +805,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -866,6 +854,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -866,6 +854,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') { if (env_p != nullptr && env_p[0] == '1') {
if (i == 0) { if (i == 0) {
_ub_comm->use_ce = 0;
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
true, (cudaStream_t)_stream_recv); true, (cudaStream_t)_stream_recv);
...@@ -906,6 +895,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -906,6 +895,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main)); (cudaStream_t)stream_main));
// Return the last N rows of D_buffer // Return the last N rows of D_buffer
_ub_comm->sms = ori_sms;
torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n);
return D_return; return D_return;
} // atomic_gemm_overlap_ag } // atomic_gemm_overlap_ag
...@@ -926,6 +916,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -926,6 +916,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -1078,6 +1069,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1078,6 +1069,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
at::cuda::setCurrentCUDAStream(stream_main); at::cuda::setCurrentCUDAStream(stream_main);
_ub_comm->sms = ori_sms;
return D; return D;
} // split_overlap_ag } // split_overlap_ag
...@@ -1094,6 +1086,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1094,6 +1086,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { bool accumulate, bool use_split_accumulator, at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -1149,7 +1142,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1149,7 +1142,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main);); _ubufs[0].numel(), (cudaStream_t)stream_main););
} else { } else {
...@@ -1157,6 +1150,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1157,6 +1150,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0); torch::sum_out(rs_output, reduce_buf, 0);
} }
_ub_comm->sms = ori_sms;
} }
/* /*
...@@ -1171,6 +1165,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1171,6 +1165,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output) { at::Tensor rs_output) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
...@@ -1245,7 +1240,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1245,7 +1240,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, D_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main);); _ubufs[0].numel(), (cudaStream_t)stream_main););
} else { } else {
...@@ -1259,6 +1254,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1259,6 +1254,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
} }
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
_ub_comm->sms = ori_sms;
} }
/* /*
......
...@@ -1861,6 +1861,14 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const ...@@ -1861,6 +1861,14 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
} }
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
template <typename fp8type> template <typename fp8type>
void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
const int elements, communicator *comm, cudaStream_t stream) { const int elements, communicator *comm, cudaStream_t stream) {
......
...@@ -107,7 +107,7 @@ def initialize_ub( ...@@ -107,7 +107,7 @@ def initialize_ub(
world_size = torch.distributed.get_world_size(mpi_group) world_size = torch.distributed.get_world_size(mpi_group)
local_rank = world_rank % tp_size local_rank = world_rank % tp_size
local_size = tp_size local_size = tp_size
node_id = world_rank // tp_size self_node_idx = world_rank // tp_size
num_nodes = world_size // tp_size num_nodes = world_size // tp_size
ub_callbacks = tex.UbufBootstrapCallbacks() ub_callbacks = tex.UbufBootstrapCallbacks()
else: else:
...@@ -127,13 +127,6 @@ def initialize_ub( ...@@ -127,13 +127,6 @@ def initialize_ub(
world_rank = torch.distributed.get_rank(world_group) world_rank = torch.distributed.get_rank(world_group)
world_size = torch.distributed.get_world_size(world_group) world_size = torch.distributed.get_world_size(world_group)
if world_rank == 0:
print(
f'!!! [NVTE] Bootstrapping Userbuffers with backend="{bootstrap_backend}"\n',
end="",
flush=True,
)
# Construct an intra-node communicator based on global ranks that share the same hostname # Construct an intra-node communicator based on global ranks that share the same hostname
# NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host
# address on that interface instead of the hostname. This can help avoid issues when # address on that interface instead of the hostname. This can help avoid issues when
...@@ -157,28 +150,41 @@ def initialize_ub( ...@@ -157,28 +150,41 @@ def initialize_ub(
hostnames = [None for _ in range(world_size)] hostnames = [None for _ in range(world_size)]
torch.distributed.all_gather_object(hostnames, hostname, world_group) torch.distributed.all_gather_object(hostnames, hostname, world_group)
intra_node_ranks = [] unique_hosts = []
for host in hostnames:
if host not in unique_hosts:
unique_hosts.append(host)
num_nodes = len(unique_hosts)
if num_nodes > 1:
ranks_per_node_list = [[] for _ in range(num_nodes)]
self_node_idx = -1
for i, host in enumerate(hostnames): for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname: if host == hostname:
intra_node_ranks.append(i) self_node_idx = node_idx
if len(intra_node_ranks) == world_size: assert self_node_idx >= 0, "Internal TE error!"
intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration(
ranks_per_node_list, backend=bootstrap_backend
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group)
else:
self_node_idx = 0
intra_node_group = world_group intra_node_group = world_group
local_rank = world_rank local_rank = world_rank
local_size = world_size local_size = world_size
intra_node_ranks = list(range(world_size)) intra_node_ranks = list(range(world_size))
else:
intra_node_group = torch.distributed.new_group(
backend=bootstrap_backend, ranks=intra_node_ranks
)
local_rank = torch.distributed.get_rank(intra_node_group)
local_size = torch.distributed.get_world_size(intra_node_group)
node_id = world_rank // local_size if world_rank == 0:
num_nodes = world_size // local_size print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True)
if local_rank == 0: if local_rank == 0:
print( print(
f"!!! [NVTE] Number of physical nodes: {num_nodes}\n" f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n",
+ f"!!! [NVTE] Global ranks on node {node_id}: {intra_node_ranks}\n",
end="", end="",
flush=True, flush=True,
) )
...@@ -293,7 +299,7 @@ def initialize_ub( ...@@ -293,7 +299,7 @@ def initialize_ub(
world_size, # World size world_size, # World size
local_rank, # Rank within the node local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node local_size, # Number of ranks/GPUs per node
node_id, # Node ID self_node_idx, # Node ID
num_nodes, # Number of nodes num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size) tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs num_sm, # Number of communication SMs
...@@ -313,7 +319,7 @@ def initialize_ub( ...@@ -313,7 +319,7 @@ def initialize_ub(
world_size, # World size world_size, # World size
local_rank, # Rank within the node local_rank, # Rank within the node
local_size, # Number of ranks/GPUs per node local_size, # Number of ranks/GPUs per node
node_id, # Node ID self_node_idx, # Node ID
num_nodes, # Number of nodes num_nodes, # Number of nodes
tp_size, # Tensor-parallel group size (may be different than local_size) tp_size, # Tensor-parallel group size (may be different than local_size)
num_sm, # Number of communication SMs num_sm, # Number of communication SMs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment