Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
......@@ -12,6 +12,8 @@ import argparse
import warnings
import pprint
import yaml
from contextlib import nullcontext
from functools import partial
import torch
import torch.distributed as dist
......@@ -35,8 +37,9 @@ class multi_module_model(torch.nn.Module):
self.num_layers = num_layers
self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])
def forward(self, x):
for layer in self.layers:
def forward(self, x, layer_contexts):
for layer, context in zip(self.layers, layer_contexts):
with context():
x = layer(x)
return x
......@@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None):
default=False,
help="Print out additional debug information.",
)
parser.add_argument(
"--first-last-layers-bf16",
action="store_true",
default=False,
help="Use bf16 for first and last N layers.",
)
parser.add_argument(
"--num-layers-at-start-in-bf16",
type=int,
default=0,
help="Number of layers at the start to run in bf16.",
)
parser.add_argument(
"--num-layers-at-end-in-bf16",
type=int,
default=0,
help="Number of layers at the end to run in bf16.",
)
args = parser.parse_args(argv, namespace)
if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
args.use_cuda_graphs = False
if not args.first_last_layers_bf16 and (
args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0
):
warnings.warn(
"num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when"
" first-last-layers-bf16 is enabled!"
)
args.num_layers_at_start_in_bf16 = 0
args.num_layers_at_end_in_bf16 = 0
if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers:
raise ValueError(
"num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to"
" num-layers!"
)
return args
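As a rough illustration of the new flag validation above, here is a minimal sketch with hypothetical argv values (it assumes the script's existing --num-layers flag and that all other arguments have defaults):
# Requesting bf16 boundary layers without --first-last-layers-bf16 only warns
# and resets both counts to zero.
args = _parse_args(["--num-layers-at-start-in-bf16", "2"])
assert args.num_layers_at_start_in_bf16 == 0

# Requesting more bf16 layers than the model has raises a ValueError.
_parse_args(
    [
        "--first-last-layers-bf16",
        "--num-layers-at-start-in-bf16", "2",
        "--num-layers-at-end-in-bf16", "2",
        "--num-layers", "3",
    ]
)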
......@@ -381,10 +418,21 @@ def _train(opts):
"qkv_dgrad": {"method": "ring_exchange"},
"fc1_dgrad": {"method": "ring_exchange"},
}
quantization_modes = [
(
te.module.base.UserBufferQuantizationMode.FP8
if opts.fp8
else te.module.base.UserBufferQuantizationMode.NONE
)
]
if opts.first_last_layers_bf16 and opts.fp8:
quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE)
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
opts.tp,
use_fp8=opts.fp8,
quantization_modes=quantization_modes,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
......@@ -423,6 +471,16 @@ def _train(opts):
elif opts.quantization == "mxfp8":
fp8_recipe = MXFP8BlockScaling()
layer_contexts = [
(
partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
if opts.num_layers_at_start_in_bf16 <= i
and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
else nullcontext
)
for i in range(opts.num_layers)
]
# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
test_x.retain_grad()
......@@ -435,8 +493,7 @@ def _train(opts):
# Execute fwd/bwd and collect tensors to test
def run_fwd_bwd(model, x):
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x)
y = model(x, layer_contexts)
if isinstance(y, tuple):
out, *_ = y
else:
......
......@@ -506,7 +506,13 @@ def main() -> None:
model_config.num_heads * model_config.head_dim,
],
torch.distributed.get_world_size(group),
use_fp8=model_config.quantization is not None,
quantization_modes=[
(
te.module.base.UserBufferQuantizationMode.FP8
if model_config.quantization is not None
else te.module.base.UserBufferQuantizationMode.NONE
)
],
dtype=model_config.dtype,
bootstrap_backend=bootstrap_backend,
ub_cfgs=userbuffer_configs,
......
......@@ -2,8 +2,11 @@
#
# See LICENSE for license information.
import contextlib
import gc
import os
from contextlib import nullcontext
from typing import Iterable, Optional
import pytest
import torch
......@@ -11,15 +14,16 @@ import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
from utils import ModelConfig, get_available_attention_backends
# Check if FP8 is supported
# Check supported quantization schemes
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_recipes = [None]
quantization_recipes: list[Optional[recipe.Recipe]] = [None]
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling()))
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
......@@ -48,85 +52,139 @@ model_types = {
"transformer_layer": lambda: te.TransformerLayer(
SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
),
"linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
"layernorm_mlp_ops": lambda: te.ops.Sequential(
te.ops.LayerNorm(SIZE, dtype=torch.bfloat16),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
te.ops.GELU(),
te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16),
),
}
def _get_input():
return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda()
def _make_input() -> torch.Tensor:
"""Generate random input tensor."""
return torch.randn(
(128, SIZE, SIZE),
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
def _get_fp8_weight_cache_size(models, fp8_recipe):
"""
Calculate the total FP8 weight cache size (in MB) for a list of models.
"""
if fp8_recipe is None:
def _warmup_model(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> None:
"""Perform forward and backward pass"""
tensor = _make_input()
for module in modules:
with te.fp8_autocast(
enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
):
tensor = module(tensor)
tensor.sum().backward()
def _estimate_cached_weight_size(
model_name: str,
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
) -> float:
"""Calculate the memory (in MiB) needed for weight caching."""
# The weight params are cached directly for unquantized compute
if quantization_recipe is None:
return 0
params_bytes = 0
for model in models:
for name, param in model.named_parameters():
if "weight" in name:
params_bytes += param.numel()
# Count number of weight param elements
param_elements = 0
for module in modules:
for param in module.parameters():
if param.dim() == 2:
param_elements += param.numel()
# FP8 tensor-scaling caches one byte per element
if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling():
if not is_non_tn_fp8_gemm_supported() and model_name not in (
"linear_op",
"layernorm_mlp_ops",
):
# Modules do not deallocate FP8 transpose for weights
return 2 * param_elements / 1024**2
return param_elements / 1024**2
# MXFP8 caches one data byte per element and one scale byte per 32
# elements
if quantization_recipe.mxfp8():
if model_name not in ("linear_op", "layernorm_mlp_ops"):
# Modules do not deallocate column-wise MXFP8 data for weights
return 2 * param_elements * (1 + 1 / 32) / 1024**2
return param_elements * (1 + 1 / 32) / 1024**2
raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})")
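For a sense of the sizes involved, the same arithmetic applied to one hypothetical 4096x4096 weight (illustrative numbers, not taken from the test):
elements = 4096 * 4096                             # 16,777,216 weight elements
fp8_mib = 2 * elements / 1024**2                   # row-wise + transposed FP8 copies -> 32.0 MiB
mxfp8_mib = 2 * elements * (1 + 1 / 32) / 1024**2  # plus 1 scale byte per 32 elements -> 33.0 MiB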
def _measure_cached_memory(
modules: Iterable[torch.nn.Module],
quantization_recipe: Optional[recipe.Recipe],
cpu_offload: bool,
) -> float:
"""Measure the growth in allocated GPU memory in MiB after a model forward pass.
Memory measurement excludes the input and output tensors.
# One byte for columnwise and one byte for rowwise,
# hence multiply by 2 and convert to MB
# there is 1 byte of scale per 32 elements in mxFP8
factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
"""
# Reset memory
gc.collect()
torch.cuda.empty_cache()
def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
tensor = _get_input()
# Context and sync function for CPU offloading
if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context(
enabled=True,
num_layers=len(models) - 1,
model_layers=len(models),
num_layers=len(modules),
model_layers=len(modules) + 1,
offload_activations=True,
offload_weights=False,
)
else:
offload_context = nullcontext()
offload_context = contextlib.nullcontext()
sync_function = lambda x: x
for model in models:
# Forward pass, with dummy step to trigger offload for last module
inp = _make_input()
tensor = inp
memory_before_forward = torch.cuda.memory_allocated() / (1024**2)
for module in modules:
with te.fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe
), offload_context:
tensor = model(tensor)
tensor = module(tensor)
tensor = sync_function(tensor)
with offload_context:
tensor = tensor.clone()
tensor = sync_function(tensor)
memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2)
max_mem_used = torch.cuda.memory_allocated() / (1024**2)
torch.cuda.synchronize()
# Backward pass
tensor.sum().backward()
torch.cuda.synchronize()
return max_mem_used
# Memory usage in MiB
return memory_after_forward - memory_before_forward
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model_key", model_types.keys())
def test_cpu_offload(fp8_recipe, model_key) -> None:
"""
We run three configurations:
(1) No offloading: All activations remain on the GPU between forward and backward passes.
(2) No offloading (one layer): Only the first layer's activations remain on the GPU between
forward and backward passes.
(3) With offloading (all layers): Only the last layer's activations remain on the GPU
between forward and backward passes, while all other layers are offloaded to the CPU.
We expect the memory consumption of configurations (2) and (3) to be similar, with
the difference being the size of the FP8 cache that is not offloaded to the CPU.
We also expect this memory consumption to be smaller than in scenario (1).
"""
import gc
@pytest.mark.parametrize("quantization_recipe", quantization_recipes)
@pytest.mark.parametrize("model_name", model_types.keys())
def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None:
"""Check that CPU offloading runs and has expected memory usage."""
gc.collect()
model_cls = model_types[model_key]
models_list = [model_cls() for _ in range(NUM_LAYERS)]
if model_key in ["multihead_attention", "transformer_layer"]:
# Construct model
modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)]
if model_name in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends(
model_config["small"],
qkv_dtype=torch.bfloat16,
......@@ -138,20 +196,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
without_offloading = _measure_memory_between_forward_and_backward(
models_list, fp8_recipe, False
)
without_offloading_one_layer = _measure_memory_between_forward_and_backward(
models_list[:1], fp8_recipe, False
)
with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)
# Warmup
_warmup_model(modules_list, quantization_recipe)
assert with_offloading < without_offloading
# Measure cached memory after forward pass
memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False)
memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True)
# The only difference between the memory consumption of with_offloading
# and without_offloading_one_layer should be the size of the FP8 weights cache,
# which is not offloaded to the CPU.
memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
assert (
memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
# Check for expected memory usage
assert memory_with_offload < memory_without_offload
memory_from_cached_weights = _estimate_cached_weight_size(
model_name,
modules_list,
quantization_recipe,
)
assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Tuple, Union
from typing import Callable, Tuple, Union, List
import math
import torch
import pytest
from transformer_engine.pytorch.attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
apply_fused_qkv_rotary_pos_emb,
)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
if isinstance(output, List):
return sum(t.sum() * 2 for t in output)
else:
return output.sum() * 2
# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
if isinstance(output, List):
return sum(torch.sum(t * torch.ones_like(t)) for t in output)
else:
t = torch.ones_like(output)
return torch.sum(output * t)
......@@ -238,3 +245,131 @@ def test_fused_rope_thd(
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_qkv_rope(
dtype: torch.dtype,
seq_length: int,
hidden_size: int,
rotary_percent: float,
margin: int,
tensor_format: str,
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
if margin == 0 and start_positions == True:
# This makes sure that the `start_positions` offsets being applied
# are within the maximum length of the RoPE embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
if seq_length - margin < 0:
pytest.skip("Skipping test with seq_length - margin < 0")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
(seq_length - margin, batch_size, head_num, hidden_size * 6),
dtype=dtype,
device=device,
)
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
if start_positions
else None
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
t.requires_grad = True
rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb_q = rotary_pos_emb_q(seq_length * cp_size)
rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb_k = rotary_pos_emb_k(seq_length * cp_size)
for cp_rank in range(cp_size):
# unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison
t_clone = t.clone()
(query, key, value) = torch.split(
t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
)
query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)
query_unfused = apply_rotary_pos_emb(
query,
emb_q,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
key_unfused = apply_rotary_pos_emb(
key,
emb_k,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
value_unfused = value
loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
t,
emb_q,
emb_k,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
cp_size=cp_size,
cp_rank=cp_rank,
qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
)
loss_fused = loss_func([query_fused, key_fused, value_fused])
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(query_fused, query_unfused)
torch.testing.assert_close(key_fused, key_unfused)
torch.testing.assert_close(value_fused, value_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
......@@ -22,6 +22,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
BackwardAddRMSNorm,
BackwardLinearAdd,
BackwardLinearScale,
ForwardLinearBiasActivation,
......@@ -1545,7 +1546,10 @@ class TestBasicOps:
torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
@pytest.mark.parametrize(
"activation",
("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
)
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
......@@ -1564,7 +1568,7 @@ class TestBasicOps:
# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "reglu", "swiglu"):
if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
in_shape[-1] *= 2
# Skip invalid configurations
......@@ -1591,14 +1595,26 @@ class TestBasicOps:
y_ref: torch.Tensor
if activation == "gelu":
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(x_ref)
elif activation == "geglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
elif activation == "qgelu":
y_ref = x_ref * torch.sigmoid(1.702 * x_ref)
elif activation == "qgeglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = x1 * torch.sigmoid(1.702 * x1) * x2
elif activation == "relu":
y_ref = torch.nn.functional.relu(x_ref)
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "srelu":
y_ref = torch.nn.functional.relu(x_ref) ** 2
elif activation == "sreglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) ** 2 * x2
elif activation == "silu":
y_ref = torch.nn.functional.silu(x_ref)
elif activation == "swiglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
......@@ -1610,9 +1626,14 @@ class TestBasicOps:
recipe = make_recipe(quantization)
make_op = dict(
gelu=te_ops.GELU,
relu=te_ops.ReLU,
geglu=te_ops.GEGLU,
qgelu=te_ops.QGELU,
qgeglu=te_ops.QGEGLU,
relu=te_ops.ReLU,
reglu=te_ops.ReGLU,
srelu=te_ops.SReLU,
sreglu=te_ops.SReGLU,
silu=te_ops.SiLU,
swiglu=te_ops.SwiGLU,
)[activation]
forward = te_ops.Sequential(
......@@ -1742,25 +1763,44 @@ class TestBasicOps:
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
@pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75))
@pytest.mark.parametrize("is_training", (True, False))
@pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
@pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
@pytest.mark.parametrize("shape", ((101,), (2, 4, 16), (128, 128)))
@pytest.mark.parametrize("dtype", _dtypes)
def test_dropout(
self,
*,
prob: float,
is_training: bool,
quantization: Optional[str],
shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
):
# Skip invalid configurations
quantized_input = quantization is not None
maybe_skip_quantization(quantization, dims=shape, device=device)
# Random data
x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
x_test = x_ref.clone().requires_grad_()
dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
dy_test = dy_ref.clone()
# Note: Shift values to make sure inputs are non-zero
x_ref, x_test = make_reference_and_test_tensors(
shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_quantized=quantized_input,
)
with torch.no_grad():
x_test += 1
x_ref.copy_(x_test)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Apply dropout
op = te_ops.Dropout(prob)
......@@ -1768,17 +1808,20 @@ class TestBasicOps:
op.train()
else:
op.eval()
y = op(x_test)
y.backward(dy_test)
y_test = op(x_test)
y_test.backward(dy_test)
# Check values
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
if is_training:
mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
torch.testing.assert_close(y, x_ref * mask)
torch.testing.assert_close(x_test.grad, dy_ref * mask)
tols = dtype_tols(dtype)
mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype)
torch.testing.assert_close(y_test, x_ref * mask, **tols)
torch.testing.assert_close(dx_test, dy_ref * mask, **tols)
else:
torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0)
# Hypothesis testing for number of zeros
# Note: A Bernoulli random variable with probability p has
......@@ -1790,9 +1833,11 @@ class TestBasicOps:
# p-value is less than 1% and we assume that the dropout
# distribution is incorrect.
if is_training:
prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel()
z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel())
assert (
abs(z_score) < 2.5758
), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"
class TestFusedOps:
......@@ -2220,6 +2265,94 @@ class TestFusedOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("weight_shape", ((19,), (64,)))
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
def test_backward_add_rmsnorm(
self,
*,
weight_shape: Iterable[int],
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
eps: float = 0.3,
zero_centered_gamma: bool,
) -> None:
"""Fused backward RMNorm + add"""
# Make input and weight shapes consistent
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
if zero_centered_gamma:
y1_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
else:
y1_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
y2_ref = x_ref
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
# Implementation with fusible operations
model = te_ops.Sequential(
te_ops.MakeExtraOutput(),
te_ops.RMSNorm(
weight_shape,
eps=eps,
device=device,
dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
y1_test, y2_test = model(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], BackwardAddRMSNorm)
# Expected numerical error
tols = dtype_tols(dtype)
# Check results
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
y2_test = y2_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y1_test, y1_ref, **tols)
torch.testing.assert_close(y2_test, y2_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_add(
......
......@@ -41,16 +41,21 @@ from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
......@@ -84,7 +89,18 @@ batch_sizes = [1, 2]
all_boolean = [True, False]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_activations = [
"gelu",
"geglu",
"qgelu",
"qgeglu",
"relu",
"reglu",
"srelu",
"sreglu",
"silu",
"swiglu",
]
all_normalizations = ["LayerNorm", "RMSNorm"]
......@@ -114,15 +130,25 @@ if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
use_cutlass_grouped_gemm = [False]
# Only enable cutlass grouped gemm on Hopper
if torch.cuda.get_device_capability() == (9, 0):
use_cutlass_grouped_gemm.append(True)
def is_fused_attn_available(
config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
config: ModelConfig,
dtype: torch.dtype,
qkv_layout="bshd_bshd_bshd",
is_training=True,
deterministic=False,
):
_, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
deterministic=deterministic,
)
return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
......@@ -432,13 +458,16 @@ class TorchGroupedLinearWithPadding(nn.Module):
_supported_act = {
"geglu": nn.GELU(approximate="tanh"),
"gelu": nn.GELU(approximate="tanh"),
"reglu": nn.ReLU(),
"relu": nn.ReLU(),
"swiglu": nn.SiLU(),
"geglu": nn.GELU(approximate="tanh"),
"qgelu": TorchQuickGELU(),
"qgeglu": TorchQuickGELU(),
"relu": nn.ReLU(),
"reglu": nn.ReLU(),
"srelu": TorchSquaredRELU(),
"sreglu": TorchSquaredRELU(),
"silu": nn.SiLU(),
"swiglu": nn.SiLU(),
}
......@@ -830,7 +859,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
if not is_fused_attn_available(config, dtype):
if not is_fused_attn_available(config, dtype, deterministic=True):
pytest.skip("No attention backend available.")
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
......@@ -878,7 +907,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.")
te_gpt = TransformerLayer(
......@@ -991,7 +1022,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.")
te_mha = MultiheadAttention(
......@@ -1782,6 +1815,7 @@ def test_grouped_linear_accuracy(
bias,
delay_wgrad_compute,
parallel_mode=None,
use_cutlass=False,
):
fp8 = recipe is not None
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
......@@ -1853,11 +1887,49 @@ def test_grouped_linear_accuracy(
delay_wgrad_compute,
)
# Should be bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
for o, o_ref in zip(outputs, outputs_ref):
if use_cutlass:
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
else:
# cuBLAS implementation should be bit-wise match
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.skipif(
torch.cuda.get_device_capability() != (9, 0),
reason="Only enable CUTLASS grouped gemm on Hopper",
)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy_cutlass(
dtype,
num_gemms,
bs,
model,
fuse_wgrad_accumulation,
delay_wgrad_compute,
):
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
test_grouped_linear_accuracy(
dtype,
num_gemms,
bs,
model,
None,
False,
fuse_wgrad_accumulation,
False,
delay_wgrad_compute,
None,
use_cutlass=True,
)
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
......@@ -2525,10 +2597,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
(16, 10027, 128, 512),
],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm)
def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
torch.manual_seed(0)
z, m, k, n = shape
......@@ -2563,6 +2636,9 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
grad = True
single_output = False
if use_cutlass:
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
# Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas
if IS_HIP_EXTENSION:
ori_force_rocm_gemm = os.environ.get("NVTE_FORCE_ROCM_GEMM", None)
......@@ -2600,9 +2676,82 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
else:
del os.environ["NVTE_FORCE_ROCM_GEMM"]
# should be bit-wise match
for o, o_ref in zip(out, out_ref):
if not use_cutlass:
# cublas implementation should be bit-wise match
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
else:
torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
if use_cutlass:
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"input_quantizer",
[
Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
],
)
@pytest.mark.parametrize(
"out_quantizer",
[
Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
Float8Quantizer(
torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3
),
],
)
def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer):
# For MXFP8 and CurrentScaling, below unfused quantization should happen
# FP8 input --> cublas GEMM --> BF16 output --> Quantize to FP8 --> fp8 Output
# Skip invalid configurations
is_mxfp8_needed = isinstance(input_quantizer, MXFP8Quantizer) or isinstance(
out_quantizer, MXFP8Quantizer
)
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if is_mxfp8_needed and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
outp_type = torch.float32
quantized_out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=out_quantizer,
bias=None,
use_split_accumulator=False,
)
out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=None,
bias=None,
use_split_accumulator=False,
)
expected_quantized_out = out_quantizer(out)
# Match results against PyTorch GEMM and allow for quantization tolerance
pytorch_out = torch.matmul(
inp_fp8.dequantize().to(torch.float64),
torch.transpose(weight_fp8.dequantize().to(torch.float64), 0, 1),
)
fp8_tols = dict(rtol=0.125, atol=0.0675)
torch.testing.assert_close(
pytorch_out.to(outp_type), expected_quantized_out.dequantize(), **fp8_tols
)
# Match results between quantization happening inside vs outside general_gemm
torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize())
@pytest.mark.parametrize(
......
......@@ -36,6 +36,7 @@ import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt
# Global test configuration knobs.
......@@ -64,6 +65,7 @@ if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(None)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
......@@ -80,11 +82,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
],
outputs=[PyCustomOpDef.dt_uint8],
)
def trt_fp8_quantize(t, scale):
def trt_fp8_quantize(t, scale_inv):
"""FP8 quantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale).cuda(),
scale=1 / torch.from_numpy(scale_inv).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
......@@ -100,11 +102,11 @@ def trt_fp8_quantize(t, scale):
],
outputs=[PyCustomOpDef.dt_float],
)
def trt_fp8_dequantize(t, scale):
def trt_fp8_dequantize(t, scale_inv):
"""FP8 dequantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale).cuda(),
scale=1 / torch.from_numpy(scale_inv).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
......@@ -113,7 +115,7 @@ def trt_fp8_dequantize(t, scale):
@onnx_op(
op_type="trt::TRT_MXFP8QuantizeLinear",
op_type="trt::TRT_MXFP8DynamicQuantize",
domain="trt",
inputs=[
PyCustomOpDef.dt_float,
......@@ -592,7 +594,9 @@ def _test_export_layernorm_linear(
fname,
inp,
model,
atol=1e-3,
# For current scaling we use Float8Quantizer in tests + amax computed by hand,
# which has slightly different numerics than Float8CurrentScalingQuantizer.
atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
is_fp8=fp8_recipe is not None,
te_outputs=te_outputs,
)
......@@ -1139,3 +1143,64 @@ def test_export_ctx_manager(enabled):
with te.onnx_export(enabled):
assert is_in_onnx_export_mode() == enabled
assert is_in_onnx_export_mode() == False
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_trt_integration(fp8_recipe: recipe.Recipe):
model = te.TransformerLayer(
hidden_size=128,
ffn_hidden_size=128,
num_attention_heads=4,
).eval()
if type(fp8_recipe) == recipe.Float8CurrentScaling:
# TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
model = te.LayerNormMLP(128, 128)
inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
out_ref = model(*inps)
onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
os.close(onnx_fd)
try:
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
with te.onnx_export(enabled=True):
torch.onnx.export(
model,
inps,
onnx_path,
output_names=["output"],
dynamo=True,
custom_translation_table=te_translation_table,
)
os.system(f"trtexec --onnx={onnx_path} --saveEngine={onnx_path}.engine")
# Run TRT engine
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open(onnx_path + ".engine", "rb") as f:
engine_data = f.read()
engine = runtime.deserialize_cuda_engine(engine_data)
context = engine.create_execution_context()
context.set_tensor_address(engine.get_tensor_name(0), inps[0].data_ptr())
stream = torch.cuda.Stream()
out = torch.zeros_like(out_ref)
context.set_tensor_address("output", out.data_ptr())
context.execute_async_v3(stream_handle=stream.cuda_stream)
stream.synchronize()
# Compare TRT and TE outputs
atol = 5e-2 if fp8_recipe is not None else 1e-4
rtol = 5e-2 if fp8_recipe is not None else 1e-4
torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
finally:
try:
os.remove(onnx_path)
except FileNotFoundError:
pass
......@@ -6,6 +6,8 @@ import random
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from utils import dtype_tols
class TestParallelCrossEntropy:
......@@ -18,19 +20,25 @@ class TestParallelCrossEntropy:
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
def generate_input(
self,
dtype: torch.dtype,
swap_dim: bool,
ignore_idx: bool,
device: torch.device = "cuda",
):
SQ = random.choice([64, 128])
batch = random.choice([1, 2])
vocab = random.choice([64000, 128000])
ignore = random.sample(range(0, SQ - 1), 5)
# Generate random data
if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
else:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)
if ignore_idx:
for i in ignore:
......@@ -40,9 +48,14 @@ class TestParallelCrossEntropy:
else:
self.tar_test[0][i] = -100
# Make copy of data for reference implementation
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
# Enable autograd
self.input_test.requires_grad_()
self.input_ref.requires_grad_()
def one_iteration_test(
self,
dtype: torch.dtype,
......@@ -52,18 +65,20 @@ class TestParallelCrossEntropy:
ignore_idx: bool = False,
):
# Random data
self.generate_input(dtype, swap_dim, ignore_idx)
self.input_test.requires_grad_(True)
self.input_ref.requires_grad_(True)
# Forward pass
test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None
)
ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
# Handle backward pass based on the test scenario
# Compute square to avoid trivial backward pass
test_loss = torch.square(test_loss)
ref_loss = torch.square(ref_loss)
# Backward pass
if reduce_loss:
test_loss.backward()
ref_loss.backward()
......@@ -71,16 +86,18 @@ class TestParallelCrossEntropy:
test_loss.sum().backward()
ref_loss.sum().backward()
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
if ignore_idx:
print(test_loss, ref_loss)
# Compare gradients when backward pass was called
torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
)
# Check that loss and grad input match
tols = dtype_tols(dtype)
test_loss = test_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.reshape(test_loss.size())
test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
torch.testing.assert_close(test_loss, ref_loss, **tols)
torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)
# Reset data
self.input_test = None
self.input_ref = None
self.tar_test = None
......
......@@ -105,7 +105,18 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
all_activations = [
"gelu",
"geglu",
"qgelu",
"qgeglu",
"relu",
"reglu",
"srelu",
"sreglu",
"silu",
"swiglu",
]
all_normalizations = ["LayerNorm", "RMSNorm"]
......
......@@ -266,8 +266,8 @@ def get_available_attention_backends(
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
available_backends,
......
......@@ -102,6 +102,11 @@ if(USE_ROCM)
message(STATUS "USE_HIPBLASLT ${USE_HIPBLASLT} USE_ROCBLAS ${USE_ROCBLAS}")
endif()
set(CUTLASS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include")
set(CUTLASS_TOOLS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include")
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
......@@ -128,6 +133,7 @@ if(USE_CUDA)
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
......@@ -139,6 +145,7 @@ if(USE_CUDA)
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
......@@ -169,6 +176,10 @@ if(USE_CUDA)
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp)
endif()
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
else()
list(APPEND transformer_engine_SOURCES
......@@ -192,10 +203,12 @@ else()
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu
dropout/dropout.cu
activation/relu.cu
activation/swiglu.cu
gemm/cublaslt_gemm.cu
gemm/hipblas_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
......@@ -226,6 +239,10 @@ else()
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp)
if (NVTE_WITH_CUBLASMP)
list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp)
endif()
# process source code files
message("${message_line}")
message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
......@@ -272,7 +289,12 @@ if (USE_CUDA)
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
${CUTLASS_TOOLS_INCLUDE_DIR})
else()
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
# Aotriton is currently unsupported
......@@ -313,11 +335,23 @@ if (NVTE_ENABLE_NVSHMEM)
target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
endif()
option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
if (NVTE_ENABLE_NVSHMEM)
add_subdirectory(nvshmem_api)
target_link_libraries(transformer_engine PUBLIC nvshmemapi)
target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
find_library(CUBLASMP_LIB
NAMES cublasmp libcublasmp
PATHS ${CUBLASMP_DIR}
PATH_SUFFIXES lib
REQUIRED)
find_library(NVSHMEM_HOST_LIB
NAMES nvshmem_host libnvshmem_host.so.3
PATHS ${NVSHMEM_DIR}
PATH_SUFFIXES lib
REQUIRED)
target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
endif()
if (USE_CUDA)
......
......@@ -218,6 +218,11 @@ def _nvidia_cudart_include_dir() -> str:
except ModuleNotFoundError:
return ""
# Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia"
# above doesn't through. However, they don't set "__file__" attribute.
if nvidia.__file__ is None:
return ""
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
return str(include_dir) if include_dir.exists() else ""
......@@ -295,6 +300,38 @@ def _load_nvrtc():
return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_curand():
"""Load cuRAND shared library."""
# Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuRAND in Python dist-packages
found, handle = _load_nvidia_cuda_library("curand")
if found:
return handle
# Attempt to locate cuRAND via ldconfig
libs = subprocess.check_output(
f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True
)
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcurand" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_core_library():
"""Load shared library with Transformer Engine C extensions"""
......@@ -305,6 +342,7 @@ if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE
try:
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
# Needed to find the correct headers for NVRTC kernels.
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/comm_gemm.h"
#include <cublasmp.h>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../common.h"
#include "../util/logging.h"
using namespace transformer_engine;
namespace {
// TODO: log warnings on failures of the *Destroy calls below, once TE has such ability.
// For now, just silently ignoring the errors, since the only diag available in TE is throwing
// exceptions, but these calls will typically be made from destructors, so cannot throw.
template <typename HandlePtr, typename CreateFn, typename DestroyFn, typename... Args>
auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) {
using Handle = std::remove_pointer_t<HandlePtr>;
HandlePtr raw{};
NVTE_CHECK_CUDA(create_fn(&raw, std::forward<Args>(args)...));
return std::unique_ptr<Handle, DestroyFn>(raw, destroy_fn);
}
using CudaStream =
std::unique_ptr<std::remove_pointer_t<cudaStream_t>, decltype(&cudaStreamDestroy)>;
CudaStream CudaStreamCreate() {
return CreateWithCudaCheck<cudaStream_t>(cudaStreamCreate, cudaStreamDestroy);
}
using CudaEvent = std::unique_ptr<std::remove_pointer_t<cudaEvent_t>, decltype(&cudaEventDestroy)>;
CudaEvent CudaEventCreate(unsigned flags) {
return CreateWithCudaCheck<cudaEvent_t>(cudaEventCreateWithFlags, cudaEventDestroy, flags);
}
template <bool raw_last, typename HandlePtr, typename CreateFn, typename DestroyFn,
typename... Args>
auto CreateWithCublasMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) {
using Handle = std::remove_pointer_t<HandlePtr>;
HandlePtr raw{};
if constexpr (raw_last) {
NVTE_CHECK_CUBLASMP(create_fn(std::forward<Args>(args)..., &raw));
} else {
NVTE_CHECK_CUBLASMP(create_fn(&raw, std::forward<Args>(args)...));
}
return std::unique_ptr<Handle, DestroyFn>(raw, destroy_fn);
}
using CublasMp =
std::unique_ptr<std::remove_pointer_t<cublasMpHandle_t>, decltype(&cublasMpDestroy)>;
CublasMp CublasMpCreate(cudaStream_t stream) {
return CreateWithCublasMpCheck<false, cublasMpHandle_t>(cublasMpCreate, cublasMpDestroy, stream);
}
using CublasMpGrid =
std::unique_ptr<std::remove_pointer_t<cublasMpGrid_t>, decltype(&cublasMpGridDestroy)>;
CublasMpGrid CublasMpGridCreate(int64_t nprow, int64_t npcol, cublasMpGridLayout_t layout,
ncclComm_t comm) {
return CreateWithCublasMpCheck<true, cublasMpGrid_t>(cublasMpGridCreate, cublasMpGridDestroy,
nprow, npcol, layout, comm);
}
using CublasMpMatrixDesc = std::unique_ptr<std::remove_pointer_t<cublasMpMatrixDescriptor_t>,
decltype(&cublasMpMatrixDescriptorDestroy)>;
CublasMpMatrixDesc CublasMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb,
int64_t rsrc, int64_t csrc, int64_t lld,
cudaDataType_t type, cublasMpGrid_t grid) {
return CreateWithCublasMpCheck<true, cublasMpMatrixDescriptor_t>(
cublasMpMatrixDescriptorCreate, cublasMpMatrixDescriptorDestroy, m, n, mb, nb, rsrc, csrc,
lld, type, grid);
}
using CublasMpMatmulDesc = std::unique_ptr<std::remove_pointer_t<cublasMpMatmulDescriptor_t>,
decltype(&cublasMpMatmulDescriptorDestroy)>;
CublasMpMatmulDesc CublasMpMatmulDescCreate(cublasComputeType_t compute_type) {
return CreateWithCublasMpCheck<false, cublasMpMatmulDescriptor_t>(
cublasMpMatmulDescriptorCreate, cublasMpMatmulDescriptorDestroy, compute_type);
}
} // namespace
struct NVTECommGemmCtx {
int64_t nranks;
int64_t rank;
ncclComm_t comm;
CudaStream stream;
CudaEvent event;
CublasMp cublas_mp;
CublasMpGrid grid_col_major;
CublasMpGrid grid_row_major;
CublasMpMatrixDesc a_desc;
CublasMpMatrixDesc b_desc;
CublasMpMatrixDesc d_desc;
CublasMpMatmulDesc matmul_desc;
void* workspace;
size_t workspace_size;
};
namespace {
int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) {
// Use non-cyclic layout to maximize opportunity for comm overlap.
return (global_size + ctx->nranks - 1) / ctx->nranks;
}
void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
bool transb) {
const auto a0 = a->flat_first_dim();
const auto a1 = a->flat_last_dim();
const auto b0 = b->flat_first_dim();
const auto b1 = b->flat_last_dim();
const auto d0 = d->flat_first_dim();
const auto d1 = d->flat_last_dim();
if (transa) {
NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, k, block_size(ctx, m), 0, 0, k,
get_cuda_dtype(a->dtype()),
ctx->grid_row_major.get(), ctx->a_desc.get()));
} else {
NVTE_CHECK(a0 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a0);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, block_size(ctx, m), k, 0, 0,
block_size(ctx, m), get_cuda_dtype(a->dtype()),
ctx->grid_col_major.get(), ctx->a_desc.get()));
}
if (transb) {
NVTE_CHECK(b0 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b0);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), k, 0, 0,
block_size(ctx, n), get_cuda_dtype(b->dtype()),
ctx->grid_col_major.get(), ctx->b_desc.get()));
} else {
NVTE_CHECK(b1 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b1);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, k, block_size(ctx, n), 0, 0, k,
get_cuda_dtype(b->dtype()),
ctx->grid_row_major.get(), ctx->b_desc.get()));
}
NVTE_CHECK(d0 == n, "Unsupported tensor dimension in D: expected ", n, ", got ", d0);
*ldd = block_size(ctx, m);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, block_size(ctx, m), block_size(ctx, n), 0,
0, *ldd, get_cuda_dtype(d->dtype()),
ctx->grid_col_major.get(), ctx->d_desc.get()));
}
void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
bool transb) {
const auto a0 = a->flat_first_dim();
const auto a1 = a->flat_last_dim();
const auto b0 = b->flat_first_dim();
const auto b1 = b->flat_last_dim();
const auto d0 = d->flat_first_dim();
const auto d1 = d->flat_last_dim();
if (transa) {
NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0,
block_size(ctx, k), get_cuda_dtype(a->dtype()),
ctx->grid_col_major.get(), ctx->a_desc.get()));
} else {
NVTE_CHECK(a1 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a1);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, m, block_size(ctx, k), 0, 0, m,
get_cuda_dtype(a->dtype()),
ctx->grid_row_major.get(), ctx->a_desc.get()));
}
if (transb) {
NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n),
get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get()));
} else {
NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
k, n, block_size(ctx, k), block_size(ctx, n), 0, 0, block_size(ctx, k),
get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get()));
}
NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1);
*ldd = m;
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd,
get_cuda_dtype(d->dtype()),
ctx->grid_row_major.get(), ctx->d_desc.get()));
}
void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
bool transb) {
const auto a0 = a->flat_first_dim();
const auto a1 = a->flat_last_dim();
const auto b0 = b->flat_first_dim();
const auto b1 = b->flat_last_dim();
const auto d0 = d->flat_first_dim();
const auto d1 = d->flat_last_dim();
if (transa) {
NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0,
block_size(ctx, k), get_cuda_dtype(a->dtype()),
ctx->grid_col_major.get(), ctx->a_desc.get()));
} else {
NVTE_ERROR("N transpose flag is not supported for input A");
}
if (transb) {
NVTE_ERROR("T transpose flag is not supported for input B");
} else {
NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0);
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, block_size(ctx, k), n, 0, 0,
block_size(ctx, k), get_cuda_dtype(b->dtype()),
ctx->grid_col_major.get(), ctx->b_desc.get()));
}
NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1);
*ldd = m;
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n * ctx->nranks, m, n, 0, 0, *ldd,
get_cuda_dtype(d->dtype()),
ctx->grid_row_major.get(), ctx->d_desc.get()));
const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue));
}
using InitMatricesFn = void (*)(NVTECommGemmCtx*, int64_t*, int64_t, int64_t, int64_t,
const Tensor*, const Tensor*, const Tensor*, bool, bool);
cublasMpMatmulAlgoType_t cublasmp_algo(NVTECommGemmAlgoType algo) {
static const std::unordered_map<NVTECommGemmAlgoType, cublasMpMatmulAlgoType_t> s_map{
{kNVTECommGemmAlgoDefault, CUBLASMP_MATMUL_ALGO_TYPE_DEFAULT},
{kNVTECommGemmAlgoSplitP2P, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_P2P},
{kNVTECommGemmAlgoSplitMulticast, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_MULTICAST},
{kNVTECommGemmAlgoAtomicP2P, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_P2P},
{kNVTECommGemmAlgoAtomicMulticast, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_MULTICAST},
};
auto it = s_map.find(algo);
return it != s_map.end() ? it->second : static_cast<cublasMpMatmulAlgoType_t>(algo);
}
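// Mapping example: kNVTECommGemmAlgoSplitP2P selects CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_P2P;
// a value not present in the table falls through the static_cast and is forwarded to
// cuBLASMp unchanged, so raw cublasMpMatmulAlgoType_t values can be passed as-is.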
void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECommGemmAlgoType algo,
int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b,
const Tensor* d, const Tensor* bias, const Tensor* pre_act_out, bool transa,
bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t main_stream) {
for (auto t : {a, b, d}) {
NVTE_CHECK(is_tensor_scaling(t->scaling_mode),
"Unsupported scaling mode: " + std::to_string(t->scaling_mode));
}
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorInit(ctx->matmul_desc.get(), CUBLAS_COMPUTE_32F));
int64_t ldd{};
init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb);
const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a,
sizeof trans_a));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b,
sizeof trans_b));
cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr,
sizeof algo_attr));
const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
if (is_fp8_dtype(a->dtype())) {
NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
&a->scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(b->dtype())) {
NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
&b->scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(d->dtype())) {
NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER,
&d->scale.dptr, sizeof(void*)));
if (d->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER,
&d->amax.dptr, sizeof(void*)));
}
}
// The epilogue might already be set to ALLREDUCE, so OR it with the new flags rather than overwriting it.
cublasMpMatmulEpilogue_t epilogue{};
size_t size_read{};
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue, &size_read));
NVTE_CHECK(size_read == sizeof epilogue);
// (bias, gelu, grad) -> epilogue
const std::map<std::tuple<bool, bool, bool>, cublasMpMatmulEpilogue_t> flags_to_epilogue{
{{true, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX_BIAS},
{{true, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU_BGRAD},
{{true, false, false}, CUBLASMP_MATMUL_EPILOGUE_BIAS},
{{true, false, true}, CUBLASMP_MATMUL_EPILOGUE_BGRADB},
{{false, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX},
{{false, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU},
};
if (auto it =
flags_to_epilogue.find({bias ? bias->data.dptr != nullptr : false,
pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad});
it != flags_to_epilogue.end()) {
epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue));
}
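// Example of the flag combination this code constructs: on the all-reduce path with a
// bias pointer and grad == false, the descriptor ends up with
// CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE | CUBLASMP_MATMUL_EPILOGUE_BIAS.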
if (bias && bias->data.dptr) {
cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type,
sizeof bias_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr,
sizeof bias->data.dptr));
}
if (pre_act_out && pre_act_out->data.dptr) {
cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE,
&aux_type, sizeof aux_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER,
&pre_act_out->data.dptr, sizeof pre_act_out->data.dptr));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd,
sizeof ldd));
if (is_fp8_dtype(pre_act_out->dtype())) {
NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE,
&scale_mode, sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER,
&pre_act_out->scale.dptr, sizeof(void*)));
if (pre_act_out->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER,
&pre_act_out->amax.dptr, sizeof(void*)));
}
}
}
if (comm_sm_count) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT,
&comm_sm_count, sizeof comm_sm_count));
}
NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream));
size_t wrksp_size_device{};
size_t wrksp_size_host{};
float alpha = 1.0;
float beta = accumulate ? 1.0 : 0.0;
std::tuple args{ctx->cublas_mp.get(),
ctx->matmul_desc.get(),
m,
n,
k,
&alpha,
a->data.dptr,
1,
1,
ctx->a_desc.get(),
b->data.dptr,
1,
1,
ctx->b_desc.get(),
&beta,
accumulate ? d->data.dptr : nullptr,
1,
1,
accumulate ? ctx->d_desc.get() : nullptr,
d->data.dptr,
1,
1,
ctx->d_desc.get()};
NVTE_CHECK_CUBLASMP(
std::apply(cublasMpMatmul_bufferSize,
std::tuple_cat(args, std::tuple{&wrksp_size_device, &wrksp_size_host})));
std::vector<uint8_t> workspace_host(wrksp_size_host);
if (ctx->workspace_size < wrksp_size_device) {
nvshmem_free(ctx->workspace);
ctx->workspace = nvshmem_malloc(wrksp_size_device);
ctx->workspace_size = wrksp_size_device;
}
NVTE_CHECK_CUBLASMP(
std::apply(cublasMpMatmul,
std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size,
workspace_host.data(), workspace_host.size()})));
NVTE_CHECK_CUDA(cudaEventRecord(ctx->event.get(), main_stream));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->event.get(), 0));
}
} // namespace
NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) {
NVTE_API_CALL(nvte_comm_gemm_ctx_create);
auto stream = CudaStreamCreate();
auto event = CudaEventCreate(cudaEventDisableTiming);
auto cublas_mp = CublasMpCreate(stream.get());
auto col_major = CublasMpGridCreate(nranks, 1, CUBLASMP_GRID_LAYOUT_COL_MAJOR, comm);
auto row_major = CublasMpGridCreate(1, nranks, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, comm);
// Pre-create matrix descriptors here; they will be re-initialized with the actual params later.
auto a_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
auto b_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
auto d_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
auto matmul_desc = CublasMpMatmulDescCreate(CUBLAS_COMPUTE_32F);
return new NVTECommGemmCtx{
.nranks = nranks,
.rank = rank,
.comm = comm,
.stream = std::move(stream),
.event = std::move(event),
.cublas_mp = std::move(cublas_mp),
.grid_col_major = std::move(col_major),
.grid_row_major = std::move(row_major),
.a_desc = std::move(a_desc),
.b_desc = std::move(b_desc),
.d_desc = std::move(d_desc),
.matmul_desc = std::move(matmul_desc),
};
}
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
nvshmemx_sync_all_on_stream(ctx->stream.get());
delete ctx;
}
void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
const NVTETensor b, const NVTETensor d, const NVTETensor bias,
const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
bool accumulate, int comm_sm_count, cudaStream_t main_stream,
NVTECommGemmAlgoType algo) {
NVTE_API_CALL(nvte_all_gather_gemm);
cublasmp_gemm(AgGemmInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate,
comm_sm_count, main_stream);
}
void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
const NVTETensor a, const NVTETensor b, const NVTETensor d,
const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t main_stream, NVTECommGemmAlgoType algo) {
NVTE_API_CALL(nvte_gemm_reduce_scatter);
cublasmp_gemm(GemmRsInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate,
comm_sm_count, main_stream);
}
void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a,
const NVTETensor b, const NVTETensor d, const NVTETensor bias,
const NVTETensor pre_act_out, bool transa, bool transb, bool grad,
bool accumulate, int comm_sm_count, cudaStream_t main_stream,
NVTECommGemmAlgoType algo) {
NVTE_API_CALL(nvte_gemm_all_reduce);
cublasmp_gemm(GemmArInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate,
comm_sm_count, main_stream);
}
int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) {
NVTE_API_CALL(nvte_comm_gemm_numroc);
return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks);
}
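// Minimal usage sketch (illustrative only; NCCL setup, tensor creation, and error
// handling are assumed to happen elsewhere and are not part of this file):
//
//   NVTECommGemmCtx* ctx = nvte_comm_gemm_ctx_create(nccl_comm, nranks, rank);
//   // Each rank owns nvte_comm_gemm_numroc(ctx, global_rows) rows of a distributed matrix.
//   nvte_all_gather_gemm(ctx, m, n, k, a, b, d, bias, pre_act_out,
//                        /*transa=*/false, /*transb=*/false, /*grad=*/false,
//                        /*accumulate=*/false, /*comm_sm_count=*/0, stream,
//                        kNVTECommGemmAlgoDefault);
//   nvte_comm_gemm_ctx_destroy(ctx);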
......@@ -153,10 +153,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
DType::kInt32);
}
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0));
/*
Defining the launcher order between the communication and GEMM kernels
......@@ -169,12 +169,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
int runtime_version = 6;
#else
int runtime_version = 0;
cudaRuntimeGetVersion(&runtime_version);
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version));
#endif
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0));
if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming));
} else {
_comm_launch_event = 0;
}
......@@ -185,9 +185,13 @@ CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
if (_comm_launch_event) {
cudaEventDestroy(_comm_launch_event);
}
if (_atomic_gemm) cudaFree(_counter.dptr());
if (_atomic_gemm) {
cudaFree(_counter.dptr());
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
cudaStreamSynchronize(_stream_compute[i]);
......@@ -723,17 +727,21 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr
int comm_bytes_per_rank = comm_bytes / _tp_size;
// We use the reference to the overlap_gemm to get the streams to send and receive on, to ensure the kernels don't finish until the previous GEMM is flushed
userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
send_stream);
userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
recv_stream);
userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank,
_ub_comm, send_stream);
userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank,
_ub_comm, recv_stream);
// We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf
for (auto stream : {send_stream, recv_stream}) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
// We sync with the comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0));
}
// Next we sync with the main stream
// We have to recapture an event off the comm stream to enable CUDA graph capture; otherwise the comm stream will never be joined in the graph
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}
/***************************************************************************************************
......@@ -829,7 +837,9 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send);
cudaStreamDestroy(_stream_recv);
for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
for (size_t i = 0; i < _stream_send.size(); i++) {
cudaStreamDestroy(_stream_send[i]);
}
}
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
......
......@@ -515,7 +515,7 @@ void destroy_communicator_mpi(communicator *comm) {
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
if (comm->free_region > NVTE_MAX_REGIONS) return -1;
if (comm->free_region >= NVTE_MAX_REGIONS) return -1;
int hndl = comm->free_region;
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
size_t aligned_size = bytes;
......
......@@ -2436,6 +2436,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
if (comm->push == 0) {
kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]),
reinterpret_cast<int *>(flagptr));
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
void *srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + srcoffset;
void *dstptr = reinterpret_cast<char *>(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;
......@@ -2633,8 +2634,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast<int *>(flagptr),
reinterpret_cast<int4 *>(srcptr), reinterpret_cast<int4 *>(dstptr),
signalonly ? 0 : bytes / 16, comm->ub_timeout);
if (!signalonly)
NVTE_CHECK_CUDA(cudaGetLastError());
if (!signalonly) {
kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
NVTE_CHECK_CUDA(cudaGetLastError());
}
if (comm->use_ce) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
}
......@@ -2649,30 +2653,33 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
reinterpret_cast<int *>(0 ? // temporary disable
GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
: nullptr));
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
int tp_size, int world_rank, communicator *comm, cudaStream_t stream) {
int rank_round_tp = (world_rank / tp_size) * tp_size;
for (int j = 1; j < tp_size; j++) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * tp_rank;
int recv_offset = dstoffset + bytes_per_slice * tp_rank;
userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm,
rank_round_tp + i, stream);
}
}
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
int tp_size, int world_rank, communicator *comm, cudaStream_t stream) {
int rank_round_tp = (world_rank / tp_size) * tp_size;
for (int j = tp_size - 1; j > 0; j--) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * i;
int recv_offset = dstoffset + bytes_per_slice * i;
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm,
rank_round_tp + i, stream);
}
}
......@@ -2747,24 +2754,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
producer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename fp8type, int nvec>
......@@ -2818,6 +2829,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
reduce_fp8_in_bf16_out_cuda<fp8type, nvec>
<<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size,
num_aligned_elements_per_input, tot_input_size);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
......@@ -2877,4 +2889,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud
dim3 grid(num_blocks);
reduce_bf16_cuda<nvec><<<grid, block, 0, stream>>>(
inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size);
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -27,7 +27,7 @@
using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
using ExtBarrierOp = std::function<void(ExtComm)>;
#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_REGIONS 32
#define NVTE_MAX_SMS 32
#define NVTE_MAX_OPS 32
#define NVTE_MAX_PEERS 8192
......@@ -314,10 +314,10 @@ void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cuda
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream);
int tp_size, int world_rank, communicator *comm, cudaStream_t stream);
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream);
int tp_size, int world_rank, communicator *comm, cudaStream_t stream);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
......@@ -26,12 +26,31 @@ __global__ void __launch_bounds__(1)
} // namespace
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return CUDA_R_16F;
case DType::kFloat32:
return CUDA_R_32F;
case DType::kBFloat16:
return CUDA_R_16BF;
case DType::kFloat8E4M3:
return CUDA_R_8F_E4M3;
case DType::kFloat8E5M2:
return CUDA_R_8F_E5M2;
default:
NVTE_ERROR("Invalid type");
}
}
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) {
NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float *>(t->scale.dptr),
reinterpret_cast<float *>(t->scale_inv.dptr));
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
......@@ -73,6 +92,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
dim3 grid(numBlocks, 1, 1); \
memset_kernel<vectorizedType> \
<<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes); \
NVTE_CHECK_CUDA(cudaGetLastError()); \
return; \
}
......@@ -83,7 +103,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream
if (size_in_bytes > 4096) {
// Use cudaMemsetAsync for larger sizes.
cudaMemsetAsync(ptr, value, size_in_bytes, stream);
NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream));
return;
}
......
......@@ -276,6 +276,8 @@ struct QuantizationConfig {
};
};
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t);
template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
return (((x) + ((y)-1)) / (y));
......@@ -395,9 +397,19 @@ struct BitsNumber {
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp4e2m1>;
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp4e2m1
#if CUDA_VERSION >= 12080
,
fp8e8m0
#endif
>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
#if CUDA_VERSION >= 12080
,
fp8e8m0
#endif
>;
#endif
template <typename U, DType current>
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <cmath>
#include "../common.h"
#include "../utils.cuh"
#include "transformer_engine/dropout.h"
namespace transformer_engine {
namespace {
// RNG kernels process chunks of 16 entries
constexpr size_t rng_chunk_size = 16;
// CUDA block size
constexpr size_t block_size = 128;
// Vector class to help with vectorized memory accesses
template <typename T, size_t kSize>
union Vector {
using StorageType = typename BytesToType<sizeof(T) * kSize>::Type;
StorageType storage;
T entries[kSize];
};
/* Byte-wise less-than comparison
*
* Results are stored in each byte's most-significant bit (MSB). All
* other bits are zero.
*/
__device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) {
// Compare low bits by masking MSBs and subtracting. The resulting
// MSBs are 0 if the low bits of a are less than the low bits of b.
uint32_t result = (a | 0x80808080) - (b & 0x7F7F7F7F);
// Bitwise logical op to get answer in MSBs
// Equivalent logic: result = (a == b) ? !result : b
asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result));
// Mask out everything except MSBs and return
result &= 0x80808080;
return result;
}
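// Worked example (unsigned byte compare): for bytes a = 0x20 (32) and b = 0x40 (64),
// 32 < 64, so the corresponding output byte is 0x80; for a = 0xFF and b = 0x40 it is 0x00.
// E.g. bytewise_less_than(0x20FF4010u, 0x40404040u) == 0x80000080u.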
/* Generate dropout mask with 16 bits.
*
* 1 corresponds to keep and 0 to drop.
*
* Consumes 4 values from cuRAND Philox generator.
*/
__device__ __forceinline__ uint16_t make_16bit_mask(uint64_t chunk_idx, uint64_t rng_seed,
uint64_t rng_offset,
uint32_t bytewise_drop_prob) {
// Generate random bits
curandStatePhilox4_32_10_t state;
curand_init(rng_seed, chunk_idx, rng_offset, &state);
const uint4 rand_bits = curand4(&state);
// Compute mask
// Note: bytewise_less_than fills MSBs (bits 7, 15, 23, 31). By
// shifting 2 bits after every call, every other bit will be filled.
uint32_t result = bytewise_less_than(rand_bits.x, bytewise_drop_prob);
result = (result >> 2) | bytewise_less_than(rand_bits.y, bytewise_drop_prob);
result = (result >> 2) | bytewise_less_than(rand_bits.z, bytewise_drop_prob);
result = (result >> 2) | bytewise_less_than(rand_bits.w, bytewise_drop_prob);
// Consolidate mask in lowest 16 bits
result |= result >> 17;
// Flip bits so 0 corresponds to drop
result = ~result;
return result;
}
// Dropout forward with FP16/BF16 input and output.
template <typename T>
__global__ void __launch_bounds__(block_size)
dropout_kernel_fwd_f16(const T *__restrict__ input_ptr, T *__restrict__ output_ptr,
uint8_t *__restrict__ mask_ptr,
const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks,
uint32_t bytewise_drop_prob, float scale) {
static_assert(sizeof(T) == 2);
// Each thread processes a chunk of 16 entries
const size_t gid = threadIdx.x + blockIdx.x * block_size;
const size_t nthreads = gridDim.x * block_size;
for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
// Generate dropout mask
auto local_mask =
make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob);
reinterpret_cast<uint16_t *>(mask_ptr)[chunk_idx] = local_mask;
// Read input data
using VectorType = Vector<T, rng_chunk_size>;
VectorType local_data;
local_data = reinterpret_cast<const VectorType *>(input_ptr)[chunk_idx];
// Apply dropout based on mask
#pragma unroll
for (size_t i = 0; i < rng_chunk_size; i++) {
float val = static_cast<float>(local_data.entries[i]);
if ((local_mask & 0x1) == 0) {
val = 0;
}
val *= scale;
local_data.entries[i] = static_cast<T>(val);
local_mask >>= 1;
}
// Write output data
reinterpret_cast<VectorType *>(output_ptr)[chunk_idx] = local_data;
}
}
// Dropout forward with FP8 input and FP16/BF16 output.
template <typename InputType, typename OutputType>
__global__ void __launch_bounds__(block_size)
dropout_kernel_fwd_fp8(const InputType *__restrict__ input_ptr,
const float *__restrict__ input_scale_inv_ptr,
OutputType *__restrict__ output_ptr, uint8_t *__restrict__ mask_ptr,
const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks,
uint32_t bytewise_drop_prob, float scale) {
static_assert(sizeof(InputType) == 1);
static_assert(sizeof(OutputType) == 2);
const float input_scale_inv = *input_scale_inv_ptr;
// Each thread processes a chunk of 16 entries
const size_t gid = threadIdx.x + blockIdx.x * block_size;
const size_t nthreads = gridDim.x * block_size;
for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
// Generate dropout mask
auto local_mask =
make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob);
reinterpret_cast<uint16_t *>(mask_ptr)[chunk_idx] = local_mask;
// Read input data
using InputVectorType = Vector<InputType, rng_chunk_size>;
InputVectorType local_input;
local_input = reinterpret_cast<const InputVectorType *>(input_ptr)[chunk_idx];
// Apply dropout based on mask
using OutputVectorType = Vector<OutputType, rng_chunk_size>;
OutputVectorType local_output;
#pragma unroll
for (size_t i = 0; i < rng_chunk_size; i++) {
float val = static_cast<float>(local_input.entries[i]);
val *= input_scale_inv;
if ((local_mask & 0x1) == 0) {
val = 0;
}
val *= scale;
local_output.entries[i] = static_cast<OutputType>(val);
local_mask >>= 1;
}
// Write output data
reinterpret_cast<OutputVectorType *>(output_ptr)[chunk_idx] = local_output;
}
}
// Apply dropout mask and scale.
template <typename T>
__global__ void __launch_bounds__(block_size)
apply_dropout_mask(const T *__restrict__ input_ptr, const uint8_t *__restrict__ mask_ptr,
T *__restrict__ output_ptr, size_t num_chunks, float scale) {
// Each thread processes a chunk of 8 entries.
const size_t gid = threadIdx.x + blockIdx.x * block_size;
const size_t nthreads = gridDim.x * block_size;
constexpr size_t chunk_size = 8;
for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
// Read dropout mask
uint8_t local_mask = mask_ptr[chunk_idx];
// Read input data
using VectorType = Vector<T, chunk_size>;
VectorType local_data;
local_data = reinterpret_cast<const VectorType *>(input_ptr)[chunk_idx];
// Apply dropout based on mask
#pragma unroll
for (size_t i = 0; i < chunk_size; i++) {
float val = static_cast<float>(local_data.entries[i]);
if ((local_mask & 0x1) == 0) {
val = 0;
}
val *= scale;
local_data.entries[i] = static_cast<T>(val);
local_mask >>= 1;
}
// Write output data
reinterpret_cast<VectorType *>(output_ptr)[chunk_idx] = local_data;
}
}
} // namespace
void dropout_fwd(const Tensor &input, Tensor &output, Tensor &mask, Tensor &rng_state,
float dropout_probability, cudaStream_t stream) {
// Check tensors
const size_t numel = input.numel();
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ",
"but scaling mode is ", to_string(input.scaling_mode), ".");
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
to_string(output.scaling_mode), ".");
NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Mask tensor must be plain tensor, ",
"but scaling mode is ", to_string(mask.scaling_mode), ".");
NVTE_CHECK(rng_state.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"RNG state tensor must be INT64 tensor with two entries, ", "but scaling mode is ",
to_string(rng_state.scaling_mode), ".");
NVTE_CHECK(output.dtype() == DType::kFloat16 || output.dtype() == DType::kBFloat16,
"Output tensor must be FP16/BF16 tensor, but dtype is ", to_string(output.dtype()),
".");
NVTE_CHECK(rng_state.dtype() == DType::kInt64,
"RNG state tensor must be INT64 tensor with two entries, but dtype is ",
to_string(rng_state.dtype()), ".");
NVTE_CHECK(numel % 16 == 0,
"Input tensor number of elements must be divisible by 16, but shape is ",
input.shape(), ".");
NVTE_CHECK(numel == output.numel(), "Input tensor (shape=", input.shape(),
") and output tensor (shape=", output.shape(), ") do not match.");
NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel,
" bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), ".");
NVTE_CHECK(rng_state.numel() == 2, "RNG state tensor must be INT64 tensor with two entries, ",
"but shape is ", rng_state.shape(), ".");
NVTE_CHECK(input.data.dptr != nullptr, "Input tensor is missing data.");
NVTE_CHECK(output.data.dptr != nullptr, "Output tensor is missing data.");
NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data.");
NVTE_CHECK(rng_state.data.dptr != nullptr, "RNG state tensor is missing data.");
// Convert dropout probability to scale and 8-bit representation
NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (",
dropout_probability, ").");
const float scale = 1 / (1 - dropout_probability);
uint32_t bytewise_drop_prob = static_cast<uint32_t>(std::floor(dropout_probability * 256));
bytewise_drop_prob |= bytewise_drop_prob << 8;
bytewise_drop_prob |= bytewise_drop_prob << 16;
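// Illustration (assumed value): dropout_probability = 0.25 gives scale = 1 / 0.75 and
// floor(0.25 * 256) = 64 = 0x40, replicated to 0x40404040. Each random byte in
// make_16bit_mask is then dropped when it falls below 0x40, i.e. with probability
// 64/256, so the effective drop probability is quantized to floor(p * 256) / 256.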
// CUDA config
const size_t num_chunks = numel / rng_chunk_size;
const size_t num_blocks = DIVUP(num_chunks, block_size);
// Launch kernel depending on input dtype
if (input.dtype() == DType::kFloat16 || input.dtype() == DType::kBFloat16) {
NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()),
") and output tensor (dtype=", to_string(output.dtype()), ") do not match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
input.dtype(), DType,
dropout_kernel_fwd_f16<DType><<<num_blocks, block_size, 0, stream>>>(
reinterpret_cast<const DType *>(input.data.dptr),
reinterpret_cast<DType *>(output.data.dptr),
reinterpret_cast<uint8_t *>(mask.data.dptr),
reinterpret_cast<const uint64_t *>(rng_state.data.dptr), num_chunks, bytewise_drop_prob,
scale););
NVTE_CHECK_CUDA(cudaGetLastError());
} else if (input.dtype() == DType::kFloat8E4M3 || input.dtype() == DType::kFloat8E5M2) {
NVTE_CHECK(input.scale_inv.dptr != nullptr, "Input tensor scale-inverse is not allocated.");
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.dtype(), InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
output.dtype(), OutputType,
dropout_kernel_fwd_fp8<InputType, OutputType><<<num_blocks, block_size, 0, stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const float *>(input.scale_inv.dptr),
reinterpret_cast<OutputType *>(output.data.dptr),
reinterpret_cast<uint8_t *>(mask.data.dptr),
reinterpret_cast<const uint64_t *>(rng_state.data.dptr), num_chunks,
bytewise_drop_prob, scale);
););
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
NVTE_ERROR("Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ",
"but dtype is ", to_string(input.dtype()), ".");
}
}
void dropout_bwd(const Tensor &grad_output, const Tensor &mask, Tensor &grad_input,
float dropout_probability, cudaStream_t stream) {
// Check tensors
const size_t numel = grad_output.numel();
NVTE_CHECK(grad_output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Grad output tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
to_string(grad_output.scaling_mode), ".");
NVTE_CHECK(grad_input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Grad input tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
to_string(grad_input.scaling_mode), ".");
NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Mask tensor must be a plain tensor, but scaling mode is ",
to_string(mask.scaling_mode), ".");
NVTE_CHECK(grad_output.dtype() == DType::kFloat16 || grad_output.dtype() == DType::kBFloat16,
"Grad output tensor must be FP16/BF16 tensor, but dtype is ",
to_string(grad_output.dtype()), ".");
NVTE_CHECK(grad_output.dtype() == grad_input.dtype(),
"Grad output tensor (dtype=", to_string(grad_output.dtype()),
") and grad input tensor (dtype=", to_string(grad_input.dtype()), ") do not match.");
NVTE_CHECK(numel % 16 == 0,
"Grad output tensor number of elements must be divisible by 16, but shape is ",
grad_output.shape(), ".");
NVTE_CHECK(numel == grad_input.numel(), "Grad output tensor (shape=", grad_output.shape(),
") and grad input tensor (shape=", grad_input.shape(), ") do not match.");
NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel,
" bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), ".");
NVTE_CHECK(grad_output.data.dptr != nullptr, "Grad output tensor is missing data.");
NVTE_CHECK(grad_input.data.dptr != nullptr, "Grad input tensor is missing data.");
NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data.");
// Convert dropout probability to scale
NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (",
dropout_probability, ").");
const float scale = 1 / (1 - dropout_probability);
// CUDA config
const size_t num_chunks = numel / 8;
const size_t num_blocks = DIVUP(num_chunks, block_size);
// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
grad_output.dtype(), DType,
apply_dropout_mask<DType><<<num_blocks, block_size, 0, stream>>>(
reinterpret_cast<const DType *>(grad_output.data.dptr),
reinterpret_cast<const uint8_t *>(mask.data.dptr),
reinterpret_cast<DType *>(grad_input.data.dptr), num_chunks, scale););
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace transformer_engine
void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask,
NVTETensor rng_state, float dropout_probability, cudaStream_t stream) {
NVTE_API_CALL(nvte_dropout_fwd);
using namespace transformer_engine;
dropout_fwd(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
*convertNVTETensorCheck(mask), *convertNVTETensorCheck(rng_state),
dropout_probability, stream);
}
void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input,
float dropout_probability, cudaStream_t stream) {
NVTE_API_CALL(nvte_dropout_bwd);
using namespace transformer_engine;
dropout_bwd(*convertNVTETensorCheck(grad_output), *convertNVTETensorCheck(mask),
*convertNVTETensorCheck(grad_input), dropout_probability, stream);
}