Commit f8c2af4c authored by yuguo's avatar yuguo
Browse files

Merge commit '1d903f5e' of...

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
......@@ -11,7 +11,7 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......
......@@ -11,6 +11,12 @@ import math
import pytest
import torch
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
......@@ -18,20 +24,15 @@ from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
TransformerLayer,
)
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
# Initialize RNG state
seed = 1234
......
......@@ -392,6 +392,110 @@ class TestFloat8BlockwiseTensor:
with pytest.raises(AssertionError):
torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize(
    "dims", [[16, 16, 512], [16, 16, 512, 16], [12, 7, 11], [13, 14, 16], [2, 3, 5]]
)
def test_view_and_reshape_1D(
    self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
) -> None:
    """Collapse a 1D-block-scaled FP8 tensor to 2D via view and reshape.

    Verifies that the dequantized values match the high-precision input,
    that the raw FP8 payload and scales are bit-identical to the source
    tensor, and that ``view`` shares storage with the original.
    """
    device = "cuda"

    def same_bits(lhs, rhs):
        # Byte-level comparison that ignores dtype and shape.
        if lhs.numel() != rhs.numel():
            return False
        lhs_bytes = lhs.reshape(-1).view(torch.uint8)
        rhs_bytes = rhs.reshape(-1).view(torch.uint8)
        return torch.all((lhs_bytes ^ rhs_bytes) == 0)

    x_hp = torch.rand(dims, dtype=dtype, device=device)
    quantizer = Float8BlockQuantizer(
        fp8_dtype=fp8_dtype,
        rowwise=True,
        columnwise=True,
        block_scaling_dim=1,
    )
    x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
    quantizer.update_quantized(x_hp.clone(), x_fp8)

    # view: high-dimensional tensor -> 2D tensor.
    x_hp_view = x_hp.view(-1, dims[-1]).contiguous()
    x_fp8_view = x_fp8.view(-1, dims[-1])
    # Dequantized result must match the reference view.
    torch.testing.assert_close(
        x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
    )
    # Inner data and scales must be bitwise identical.
    assert same_bits(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
    assert same_bits(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
    # A view must alias the original storage.
    assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
    assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()

    # reshape: high-dimensional tensor -> 2D tensor.
    x_hp_reshape = x_hp.reshape(-1, dims[-1]).contiguous()
    x_fp8_reshape = x_fp8.reshape(-1, dims[-1])
    # Dequantized result must match the reference reshape.
    torch.testing.assert_close(
        x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
    )
    # Inner data and scales must be bitwise identical.
    assert same_bits(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
    assert same_bits(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("dims", [[16, 16, 512, 16], [2, 512, 512, 128], [3, 13, 14, 16]])
def test_view_and_reshape_2D(
    self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
) -> None:
    """Collapse a 2D-block-scaled FP8 tensor to 3D via view and reshape.

    Verifies that the dequantized values match the high-precision input,
    that the raw FP8 payload and scales are bit-identical to the source
    tensor, and that ``view`` shares storage with the original.
    """
    device = "cuda"

    def same_bits(lhs, rhs):
        # Byte-level comparison that ignores dtype and shape.
        if lhs.numel() != rhs.numel():
            return False
        lhs_bytes = lhs.reshape(-1).view(torch.uint8)
        rhs_bytes = rhs.reshape(-1).view(torch.uint8)
        return torch.all((lhs_bytes ^ rhs_bytes) == 0)

    x_hp = torch.rand(dims, dtype=dtype, device=device)
    quantizer = Float8BlockQuantizer(
        fp8_dtype=fp8_dtype,
        rowwise=True,
        columnwise=True,
        block_scaling_dim=2,
    )
    x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
    quantizer.update_quantized(x_hp.clone(), x_fp8)

    # view: collapse all leading dims, keep the last two.
    x_hp_view = x_hp.view(-1, dims[-2], dims[-1]).contiguous()
    x_fp8_view = x_fp8.view(-1, dims[-2], dims[-1])
    # Dequantized result must match the reference view.
    torch.testing.assert_close(
        x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
    )
    # Inner data and scales must be bitwise identical.
    assert same_bits(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
    assert same_bits(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
    # A view must alias the original storage.
    assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
    assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()

    # reshape: collapse all leading dims, keep the last two.
    x_hp_reshape = x_hp.reshape(-1, dims[-2], dims[-1]).contiguous()
    x_fp8_reshape = x_fp8.reshape(-1, dims[-2], dims[-1])
    # Dequantized result must match the reference reshape.
    torch.testing.assert_close(
        x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
    )
    # Inner data and scales must be bitwise identical.
    assert same_bits(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
    assert same_bits(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
......
......@@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
......@@ -400,7 +400,7 @@ class TestCurrentScalingFloat8Tensor:
"""Check numerical error when casting to FP8"""
# Skip invalid configurations
if non_tn_fp8_gemm_supported() and return_transpose:
if is_non_tn_fp8_gemm_supported() and return_transpose:
pytest.skip("FP8 transpose is neither needed nor supported on current system")
# Initialize random high precision data
......
......@@ -12,10 +12,11 @@ from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import gpu_autocast_ctx
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......@@ -596,7 +597,7 @@ class AdamTest:
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
......@@ -605,7 +606,7 @@ class AdamTest:
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
......@@ -647,7 +648,7 @@ class AdamTest:
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
......@@ -656,7 +657,7 @@ class AdamTest:
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
......@@ -705,7 +706,7 @@ class AdamTest:
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
......@@ -714,7 +715,7 @@ class AdamTest:
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Tuple, Union
import math
import pytest
import torch
from typing import Callable, Tuple, Union
from transformer_engine.pytorch.dot_product_attention.rope import (
import pytest
from transformer_engine.pytorch.attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
......@@ -22,6 +22,7 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return torch.sum(output * t)
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
......@@ -43,7 +44,17 @@ def test_fused_rope(
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
if margin == 0 and start_positions == True:
# This makes sure that the `start_positions` offsets being applied
# are with the maximum length of the rope embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
......@@ -51,6 +62,14 @@ def test_fused_rope(
dtype=dtype,
device=device,
)
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
if start_positions
else None
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
if transpose:
......@@ -69,14 +88,18 @@ def test_fused_rope(
t.float(),
emb,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
......@@ -84,21 +107,29 @@ def test_fused_rope(
t,
emb,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@pytest.mark.parametrize("margin", [10])
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
......@@ -114,10 +145,25 @@ def test_fused_rope_thd(
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
margin: int,
) -> None:
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
if start_positions
else None
)
if cp_size > 1:
cu_seqlens_padded = [0]
for i in range(1, len(cu_seqlens)):
......@@ -152,6 +198,7 @@ def test_fused_rope_thd(
output_unfused = apply_rotary_pos_emb(
t.float(),
emb,
start_positions=start_positions,
tensor_format="thd",
interleaved=interleaved,
fused=False,
......@@ -160,6 +207,8 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
......@@ -168,6 +217,7 @@ def test_fused_rope_thd(
output_fused = apply_rotary_pos_emb(
t,
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
tensor_format="thd",
......@@ -176,9 +226,15 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
......@@ -160,7 +160,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
normab = torch.cat((a.norm().view(1), b.norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], False)
reference = torch.full(
[(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
......
......@@ -7,7 +7,6 @@ import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
import torch
......@@ -40,12 +39,12 @@ from transformer_engine.pytorch import (
Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
from transformer_engine.common import recipe
import transformer_engine_torch as tex
......@@ -135,18 +134,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
tols = dict(atol=atol)
tols = dtype_tols(t2.dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
tol = atol + (rtol * torch.abs(t2))
tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
exceed_mask = diff > tol
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
......@@ -2304,6 +2305,12 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE:
pytest.skip("KV cache does not support starting positions for RoPE")
if (
backend == "FusedAttention"
and get_device_compute_capability() == (8, 9)
and get_cudnn_version() < (9, 11, 0)
):
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
......
......@@ -19,11 +19,12 @@ class TestParallelCrossEntropy:
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
def generate_input(self, dtype: torch.dtype, swap_dim: bool):
def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
SQ = random.choice([64, 128])
batch = random.choice([1, 2])
vocab = random.choice([64000, 128000])
ignore = random.sample(range(0, SQ - 1), 5)
if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
......@@ -32,14 +33,27 @@ class TestParallelCrossEntropy:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
if ignore_idx:
for i in ignore:
# Ignore 5 indices
if swap_dim:
self.tar_test[i][0] = -100
else:
self.tar_test[0][i] = -100
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
def one_iteration_test(
self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
self,
dtype: torch.dtype,
swap_dim: bool,
label_smoothing: float,
reduce_loss: bool,
ignore_idx: bool = False,
):
self.generate_input(dtype, swap_dim)
self.generate_input(dtype, swap_dim, ignore_idx)
self.input_test.requires_grad_(True)
self.input_ref.requires_grad_(True)
......@@ -57,6 +71,8 @@ class TestParallelCrossEntropy:
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
if ignore_idx:
print(test_loss, ref_loss)
if reduce_loss:
torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
......@@ -106,3 +122,15 @@ class TestParallelCrossEntropy:
self.one_iteration_test(
dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
)
def test_ignore_idx(self):
    """Run several iterations with ignored (-100) target indices enabled."""
    self.generate_iters(5)
    # No label smoothing, no loss reduction for this scenario.
    self.generate_infra(False, 0)
    for _ in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice([True, False]),
            label_smoothing=0,
            reduce_loss=False,
            ignore_idx=True,
        )
......@@ -373,7 +373,9 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
torch.cuda.synchronize()
def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
def _test_sanity_common(
block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
if skip_dgrad and skip_wgrad:
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
......@@ -389,7 +391,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
if not microbatching:
te_out = block(te_inp)
else:
_ = block(te_inp, is_first_microbatch=True)
te_out = block(te_inp, is_first_microbatch=False)
if isinstance(te_out, tuple):
te_out = te_out[0]
loss = te_out.sum()
......@@ -443,8 +449,16 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
skip_dgrad,
normalization,
microbatching,
):
config = model_configs[model]
......@@ -470,7 +484,7 @@ def test_sanity_layernorm_linear(
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......@@ -478,7 +492,8 @@ def test_sanity_layernorm_linear(
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -501,7 +516,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......@@ -600,8 +615,17 @@ def test_sanity_grouped_linear(
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
skip_dgrad,
activation,
normalization,
microbatching,
):
config = model_configs[model]
......@@ -630,7 +654,7 @@ def test_sanity_layernorm_mlp(
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......
......@@ -11,12 +11,12 @@ import transformer_engine.common
try:
from . import pytorch
except (ImportError, StopIteration) as e:
except ImportError as e:
pass
try:
from . import jax
except (ImportError, StopIteration) as e:
except ImportError as e:
pass
__version__ = str(metadata.version("transformer_engine"))
......@@ -111,6 +111,11 @@ if(USE_CUDA)
cudnn_utils.cpp
transformer_engine.cpp
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
......@@ -148,6 +153,7 @@ if(USE_CUDA)
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
......@@ -158,6 +164,11 @@ else()
cudnn_utils.cpp
transformer_engine.cpp
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
......@@ -191,6 +202,7 @@ else()
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
......@@ -345,6 +357,14 @@ target_include_directories(transformer_engine PRIVATE
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
......
......@@ -9,28 +9,193 @@ import glob
import sysconfig
import subprocess
import ctypes
import logging
import os
import platform
import importlib
import functools
from pathlib import Path
from importlib.metadata import version, metadata, PackageNotFoundError
import transformer_engine
_logger = logging.getLogger(__name__)
def is_package_installed(package):
"""Checks if a pip package is installed."""
return (
subprocess.run(
[sys.executable, "-m", "pip", "show", package], capture_output=True, check=False
).returncode
== 0
@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package):
"""Check if the given package is installed via pip."""
# This is needed because we only want to return true
# if the python package is installed via pip, and not
# if it's importable in the current directory due to
# the presence of the shared library module.
try:
metadata(package)
except PackageNotFoundError:
return False
return True
@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
    """Find a shared object file with the given prefix under *te_path*.

    Only a fixed set of locations is considered, to avoid picking up stray
    shared objects or build artifacts:
      1. The top level directory itself (editable install).
      2. ``transformer_engine`` subdirectories (source install).
      3. ``wheel_lib`` subdirectories (PyPI install).

    Returns None when nothing matches; raises RuntimeError when more than
    one candidate is found.
    """
    # Bail out early unless the top-level dir exists and holds the package.
    if not te_path.exists() or not (te_path / "transformer_engine").exists():
        return None

    allowed_dirs = (
        te_path,
        te_path / "transformer_engine",
        te_path / "transformer_engine/wheel_lib",
        te_path / "wheel_lib",
    )
    suffix = f".{_get_sys_extension()}"

    matches = []
    for dirname, _, names in os.walk(te_path):
        if Path(dirname) not in allowed_dirs:
            continue
        matches.extend(
            Path(dirname, name)
            for name in names
            if name.startswith(prefix) and name.endswith(suffix)
        )

    if not matches:
        return None
    if len(matches) > 1:
        raise RuntimeError(f"Multiple files found: {matches}")
    return matches[0]
@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
    """
    Return the path of the shared object file for the given TE
    library, one of 'core', 'torch', or 'jax'.

    Several factors affect finding the correct location of the shared object:
    1. System and environment.
    2. If the installation is from source or via PyPI.
       - Source installed .sos are placed in top level dir
       - Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
    3. For source installations, is the install editable/inplace?
    4. The user directory from where TE is being imported.

    Raises RuntimeError when no shared object can be found, or when
    conflicting copies exist in both the import location and site-packages.
    """
    # Check provided input and determine the correct prefix for .so.
    assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}."
    if library == "core":
        so_prefix = "libtransformer_engine"
    else:
        so_prefix = f"transformer_engine_{library}"

    # Check TE install location (will be local if TE is available in current dir for import).
    # origin is the package __init__; two .parent hops give the directory containing it.
    te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
    so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)

    # Check default python package install location in system.
    site_packages_dir = Path(sysconfig.get_paths()["purelib"])
    so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)

    # Case 1: Typical user workflow: Both locations are the same, return any result.
    if te_install_dir == site_packages_dir:
        assert (
            so_path_in_install_dir is not None
        ), f"Could not find shared object file for Transformer Engine {library} lib."
        return so_path_in_install_dir

    # Case 2: ERR! Both locations are different but returned a valid result.
    # NOTE: Unlike for source installations, pip does not wipe out artifacts from
    # editable builds. In case developers are executing inside a TE directory via
    # an inplace build, and then move to a regular build, the local shared object
    # file will be incorrectly picked up without the following logic.
    if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
        raise RuntimeError(
            f"Found multiple shared object files: {so_path_in_install_dir} and"
            f" {so_path_in_default_dir}. Remove local shared objects installed"
            f" here {so_path_in_install_dir} or change the working directory to"
            "execute from outside TE."
        )

    # Case 3: Typical dev workflow: Editable install
    if so_path_in_install_dir is not None:
        return so_path_in_install_dir

    # Case 4: Executing from inside a TE directory without an inplace build available.
    if so_path_in_default_dir is not None:
        return so_path_in_default_dir

    raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str):
"""
Load shared library with Transformer Engine framework bindings
and check verify correctness if installed via PyPI.
"""
# Supported frameworks.
assert framework in ("jax", "torch"), f"Unsupported framework {framework}"
# Name of the framework extension library.
module_name = f"transformer_engine_{framework}"
# Name of the pip extra dependency for framework extensions from PyPI.
extra_dep_name = module_name
if framework == "torch":
extra_dep_name = "pytorch"
# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
if _is_pip_package_installed(module_name):
assert _is_pip_package_installed(
"transformer_engine"
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)
# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed("transformer-engine-cu12"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
)
def get_te_path():
"""Find Transformer Engine install path using pip"""
return Path(transformer_engine.__path__[0]).parent
# After all checks are completed, load the shared object file.
spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
@functools.lru_cache(maxsize=None)
def _get_sys_extension():
system = platform.system()
if system == "Linux":
......@@ -45,20 +210,47 @@ def _get_sys_extension():
return extension
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
    """Attempt to load shared object files installed via pip.

    Searches the Python environment's ``nvidia/<lib_name>/lib`` directory for
    versioned shared objects and loads every match with ``RTLD_GLOBAL`` so
    that libraries loaded afterwards can resolve symbols from them.

    Args:
        lib_name: Name of the package as found in the ``nvidia`` dir in the
            Python environment (e.g. ``"cudnn"``, ``"cublas"``).

    Returns:
        Tuple ``(path_found, ctypes_handles)``: ``path_found`` is ``True``
        when at least one matching shared object was located, and
        ``ctypes_handles`` is the list of loaded ``ctypes.CDLL`` handles
        (empty when nothing was found).
    """
    so_paths = glob.glob(
        os.path.join(
            sysconfig.get_path("purelib"),
            f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
        )
    )
    path_found = len(so_paths) > 0
    ctypes_handles = []
    if path_found:
        for so_path in so_paths:
            ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
    return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir():
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
try:
import nvidia
except ModuleNotFoundError:
return ""
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
return str(include_dir) if include_dir.exists() else ""
@functools.lru_cache(maxsize=None)
def _load_cudnn():
"""Load CUDNN shared library."""
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
......@@ -75,28 +267,16 @@ def _load_cudnn():
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in Python dist-packages
found, handle = _load_nvidia_cuda_library("cudnn")
if found:
return handle
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
def _load_library():
    """Load shared library with Transformer Engine C extensions"""
    lib_name = f"libtransformer_engine.{_get_sys_extension()}"
    te_root = get_te_path()
    # Probe the known install layouts in order of preference: in-package,
    # wheel_lib subdirectory, then next to the package directory.
    candidates = (
        te_root / "transformer_engine" / lib_name,
        te_root / "transformer_engine" / "wheel_lib" / lib_name,
        te_root / lib_name,
    )
    for so_path in candidates:
        if so_path.exists():
            break
    assert so_path.exists(), f"Could not find {lib_name}"
    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
......@@ -107,6 +287,11 @@ def _load_nvrtc():
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC in Python dist-packages
found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
if found:
return handle
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n")
......@@ -123,10 +308,22 @@ def _load_nvrtc():
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_core_library():
    """Load shared library with Transformer Engine C extensions"""
    # Resolve the core shared object path, then load it globally so that
    # framework extensions can link against its symbols.
    core_so_path = _get_shared_object_file("core")
    return ctypes.CDLL(core_so_path, mode=ctypes.RTLD_GLOBAL)
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
try:
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
except OSError:
pass
_TE_LIB_CTYPES = _load_library()
_TE_LIB_CTYPES = _load_core_library()
......@@ -21,12 +21,18 @@
#define HALF_BYTES 2
#define UB_MAX_SM 32
#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)
using namespace std::placeholders;
namespace transformer_engine {
namespace {
// Copy a C-style (data, ndim) shape descriptor into an owning std::vector.
std::vector<size_t> shape_to_vector(const NVTEShape &shape) {
  const auto *begin = shape.data;
  const auto *end = begin + shape.ndim;
  return std::vector<size_t>(begin, end);
}
} // namespace
/***************************************************************************************************
* Comm+GEMM Overlap Common Core
**************************************************************************************************/
......@@ -147,13 +153,50 @@ CommOverlapCore::~CommOverlapCore() {
TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
const std::vector<size_t> &chunk_shape) {
TensorWrapper chunk;
const auto scaling_mode = source.scaling_mode();
// Tensor dimensions
std::vector<size_t> shape = shape_to_vector(source.shape());
auto flatten_shape_to_2d = [](const std::vector<size_t> &shape) -> std::pair<size_t, size_t> {
if (shape.empty()) {
return {1, 1};
}
size_t height = 1;
for (size_t i = 0; i < shape.size() - 1; ++i) {
height *= shape[i];
}
return {height, shape.back()};
};
size_t height, width, chunk_height, chunk_width;
std::tie(height, width) = flatten_shape_to_2d(shape);
std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape);
// Check tensor dimensions
#define NVTE_DIM_CHECK(cond, message) \
NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
", chunk offset=", chunk_offset, ")")
NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor");
NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
"Attempted to get out-of-bounds tensor chunk");
if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// MXFP8 scale-inverses are padded to a 2D matrix with dims that
// are divisible by 128. UB doesn't handle this padding yet.
NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0,
"Userbuffers requires MXFP8 tensor dims that are divisible by 128");
NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
"Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
}
#undef NVTE_DIM_CHECK
// Construct tensor chunk
TensorWrapper chunk(scaling_mode);
for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
auto param_type = static_cast<NVTETensorParam>(param_id);
auto param = source.get_parameter(param_type);
auto param_dptr = reinterpret_cast<char *>(param.data_ptr);
auto param_dtype = static_cast<DType>(param.dtype);
auto param_shape = AS_VECTOR(param.shape);
auto param_shape = shape_to_vector(param.shape);
if (param_dptr != nullptr) {
if (param_type == NVTETensorParam::kNVTERowwiseData ||
......@@ -163,8 +206,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// Columnwise shape for non-block scaled tensors shifts the last dimension to the front
source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
// Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front
auto last_dim = param_shape.back();
param_shape.pop_back();
param_shape.insert(param_shape.begin(), last_dim);
......@@ -172,18 +215,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
} else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
(param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
// Calculate block scaling offset and size
auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? source.shape().data[0]
: source.columnwise_shape().data[0];
auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? chunk_shape.front()
: chunk_shape.back();
auto chunk_scale_start = chunk_offset / 32;
auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
param_dptr += chunk_scale_start * typeToSize(param_dtype);
param_shape = std::vector<size_t>{chunk_scale_size};
// Calculate offset and size for MXFP8 scale-invs
size_t chunk_scale_height = chunk_height;
size_t chunk_scale_width = chunk_width;
if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) {
chunk_scale_width /= 32;
} else {
chunk_scale_height /= 32;
}
param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
param_shape = {chunk_scale_height, chunk_scale_width};
}
// Set chunked source parameters into the chunked tensor output
......@@ -434,10 +475,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
size_t k = transa ? A.size(1) : A.size(0);
size_t n = _ubuf.size(0);
size_t m_chunk = m / _num_splits;
const std::vector<size_t> input_a_chunk_shape =
(transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
const std::vector<size_t> output_chunk_shape = {n, m_chunk};
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Helper function to get bias chunk if needed
auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper {
if (bias.dptr() == nullptr) {
return TensorWrapper();
}
return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk});
};
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
......@@ -449,12 +501,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_rs_overlap_first_gemm) {
auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape);
auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(0);
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
} else {
......@@ -464,18 +517,19 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
bias_chunk = maybe_get_bias_chunk(i);
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
......@@ -519,13 +573,14 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
} else {
for (int i = 0; i < _num_splits; i++) {
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(i);
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
......@@ -605,14 +660,17 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
_ubuf = TensorWrapper(
buffer_ptr,
std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
{buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype));
std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
......@@ -661,7 +719,7 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape()));
auto chunk = get_tensor_chunk(source, 0, shape_to_vector(_ubufs[chunk_id].shape()));
// Update chunk with offset data pointers from the communication buffer
if (chunk.dptr() != nullptr) {
......@@ -711,7 +769,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape()));
auto input_b = get_buffer_chunk_like(B, 0, shape_to_vector(B.shape()));
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
......@@ -798,8 +856,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
......@@ -810,10 +866,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
if (_aggregate) {
const int num_steps = _tp_size / 2;
#ifndef __HIP_PLATFORM_AMD__
input_chunk_size *= 2;
output_chunk_size *= 2;
#endif
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
size_t input_b_chunk_size = 2 * n_chunk * k;
size_t output_chunk_size = 2 * n_chunk * m;
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
......@@ -842,8 +901,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// GEMM
auto input_b_chunk =
get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
......@@ -882,6 +942,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
}
} else {
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
std::vector<size_t> output_chunk_shape = {n_chunk, m};
size_t input_b_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
......@@ -893,8 +960,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
auto input_b_chunk =
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
......@@ -972,7 +1041,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape()));
auto output_d = get_buffer_chunk_like(D, 0, shape_to_vector(D.shape()));
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace.data(), accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, true, _counter.data(), stream_main);
......@@ -1053,6 +1122,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
auto output_chunk = get_buffer_chunk_by_id(D, i);
auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
......
......@@ -35,6 +35,65 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
}
}
namespace {

// Threads per block for the memset kernel launches below; each thread
// writes one TVectorized element.
constexpr size_t kThreadsPerBlock = 256;

// Fill `size_in_bytes` bytes at `ptr` with the byte `value`, one TVectorized
// store per thread. The final, partially covered element (if any) is filled
// bytewise by the last in-bounds thread.
template <typename TVectorized>
__global__ void __launch_bounds__(kThreadsPerBlock)
memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx * sizeof(TVectorized) >= size_in_bytes) {
return; // Out of bounds
}
if ((idx + 1) * sizeof(TVectorized) > size_in_bytes) {
// If the buffer size is not an even multiple of the vectorization, manually set the remaining bytes unvectorized.
size_t remaining_bytes = size_in_bytes - idx * sizeof(TVectorized);
memset(reinterpret_cast<uint8_t *>(ptr) + idx * sizeof(TVectorized), value, remaining_bytes);
return;
}
// Broadcast the byte value across a full TVectorized lane, then store it
// with a single vectorized write.
union {
TVectorized value;
uint8_t data[sizeof(TVectorized)];
} data;
for (size_t i = 0; i < sizeof(TVectorized); ++i) {
data.data[i] = static_cast<uint8_t>(value);
}
reinterpret_cast<TVectorized *>(ptr)[idx] = data.value;
}

} // namespace
// Launch memset_kernel<vectorizedType> and return from the enclosing function
// when the buffer is large enough for, and aligned to, vectorizedType;
// otherwise fall through to the next (narrower) dispatch. Deliberately not
// wrapped in do-while: the `return` must exit the caller (nvte_memset).
#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
if (size_in_bytes >= sizeof(vectorizedType) && \
reinterpret_cast<size_t>(ptr) % sizeof(vectorizedType) == 0) { \
size_t numBlocks = DIVUP(size_in_bytes, kThreadsPerBlock * sizeof(vectorizedType)); \
dim3 grid(numBlocks, 1, 1); \
memset_kernel<vectorizedType> \
<<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes); \
return; \
}

extern "C" {

// Asynchronously set `size_in_bytes` bytes at device pointer `ptr` to `value`
// on `stream`. Small buffers use a custom kernel to avoid cudaMemsetAsync
// launch overhead; large ones defer to the runtime.
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
NVTE_API_CALL(nvte_memset);
NVTE_CHECK(ptr != nullptr, "Pointer for memset must be allocated.");

if (size_in_bytes > 4096) {
// Use cudaMemsetAsync for larger sizes.
// NOTE(review): the cudaMemsetAsync return status is not checked here —
// confirm whether NVTE_CHECK_CUDA is available in this translation unit.
cudaMemsetAsync(ptr, value, size_in_bytes, stream);
return;
}
// Try the widest aligned vector type first. uint8_t always matches for
// size_in_bytes >= 1, so a zero-byte request falls through as a no-op.
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float4, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float2, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
}

} // extern "C"
void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
return;
......@@ -144,4 +203,16 @@ bool is_supported_by_CC_100() {
#endif
}
// Reinterpret a 2D C array of NVTETensor handles as a nested vector of
// transformer_engine::Tensor pointers.
//
// `nvte_tensors` must point to `outer_size` arrays of `inner_size` handles
// each. Handles are reinterpreted in place; no ownership is taken.
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
                                                        size_t outer_size, size_t inner_size) {
  std::vector<std::vector<Tensor *>> ret;
  ret.reserve(outer_size);  // sizes are known up front; avoid reallocations
  for (size_t i = 0; i < outer_size; ++i) {
    ret.emplace_back();
    ret.back().reserve(inner_size);
    for (size_t j = 0; j < inner_size; ++j) {
      ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
    }
  }
  return ret;
}
} // namespace transformer_engine
......@@ -116,7 +116,7 @@ struct Tensor {
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
int numel() const {
size_t numel() const {
size_t acc = 1;
for (const auto dim : shape()) {
acc *= dim;
......@@ -138,6 +138,14 @@ struct Tensor {
return data.dtype;
}
size_t dim() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size();
} else {
return data.shape.size();
}
}
std::vector<size_t> shape() const {
/* Note: We sometimes experience spurious compiler errors
* (-Wstringop-overflow) from this function. It appears that GCC
......@@ -243,6 +251,7 @@ constexpr T DIVUP(const T &x, const T &y) {
}
using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
......@@ -271,6 +280,7 @@ constexpr inline const char *type_name() noexcept;
return #T; \
}
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
......@@ -327,7 +337,7 @@ struct TypeExtrema {
template <typename T>
struct TypeInfo {
using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
template <typename U, DType current>
struct Helper {
......@@ -364,6 +374,10 @@ struct TypeInfo {
using type = unsigned char; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt16: { \
using type = int16_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt32: { \
using type = int32_t; \
{ __VA_ARGS__ } \
......@@ -400,6 +414,33 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
// Dispatch over floating-point DTypes only (fp32, fp16, bf16, fp8e4m3,
// fp8e5m2): binds the matching C++ type as `type` and executes __VA_ARGS__;
// any other dtype triggers NVTE_ERROR.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......@@ -599,6 +640,9 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
bool is_supported_by_CC_100();
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
This diff is collapsed.
......@@ -1006,3 +1006,18 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
// Return the number of sequence segments encoded in `cu_seqlen` (of length
// `len`), using `workspace` as scratch. Thin C wrapper over
// fused_attn::GetRuntimeNumSegments.
// NOTE(review): returns a host value, so the helper presumably synchronizes
// on `stream` internally — confirm in GetRuntimeNumSegments.
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream) {
NVTE_API_CALL(nvte_get_runtime_num_segments);
using namespace transformer_engine::fused_attn;
return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream);
}
// Asynchronously populate `rng_state_dst` from `seed` for a fused-attention
// kernel with the given max sequence lengths and backend. Thin C wrapper over
// fused_attn::PopulateRngStateAsync; work is enqueued on `stream`.
void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
size_t q_max_seqlen, size_t kv_max_seqlen,
NVTE_Fused_Attn_Backend backend, cudaStream_t stream) {
NVTE_API_CALL(nvte_populate_rng_state_async);
using namespace transformer_engine::fused_attn;
PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment