Commit f8c2af4c authored by yuguo's avatar yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
......@@ -11,7 +11,7 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......
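Note for downstream users of this merge: the `dot_product_attention` package now lives under `transformer_engine.pytorch.attention`. A minimal compatibility sketch, assuming code may run against releases on either side of the move:

```python
# Hedged sketch: resolve FlashAttentionUtils across the module reorganization.
try:
    # New layout (this merge onward)
    from transformer_engine.pytorch.attention.dot_product_attention.utils import (
        FlashAttentionUtils,
    )
except ImportError:
    # Old layout (earlier releases)
    from transformer_engine.pytorch.dot_product_attention.utils import (
        FlashAttentionUtils,
    )
```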
......@@ -11,6 +11,12 @@ import math
import pytest
import torch
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
......@@ -18,20 +24,15 @@ from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
TransformerLayer,
)
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
# Initialize RNG state
seed = 1234
......
......@@ -392,6 +392,110 @@ class TestFloat8BlockwiseTensor:
with pytest.raises(AssertionError):
torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize(
"dims", [[16, 16, 512], [16, 16, 512, 16], [12, 7, 11], [13, 14, 16], [2, 3, 5]]
)
def test_view_and_reshape_1D(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
) -> None:
"""Test view operations that preserve tensor shape"""
device = "cuda"
def is_bitwise_equal(a, b):
if a.numel() != b.numel():
return False
a_flat = a.reshape(-1).view(torch.uint8)
b_flat = b.reshape(-1).view(torch.uint8)
return torch.all((a_flat ^ b_flat) == 0)
x_hp = torch.rand(dims, dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=1,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test view: high-dimensional tensor -> 2D tensor
x_hp_view = x_hp.view(-1, dims[-1]).contiguous()
x_fp8_view = x_fp8.view(-1, dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
# Check the data ptr
assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
# Test reshape: high-dimensional tensor -> 2D tensor
x_hp_reshape = x_hp.reshape(-1, dims[-1]).contiguous()
x_fp8_reshape = x_fp8.reshape(-1, dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
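The data-pointer and bitwise checks above encode the usual view contract: a view must alias the base tensor's storage rather than copy it. The same invariant with a plain PyTorch tensor, as a minimal sketch:

```python
import torch

# A view shares storage with its base: identical data_ptr, identical bytes.
x = torch.arange(6, dtype=torch.uint8)
v = x.view(2, 3)
assert v.data_ptr() == x.data_ptr()
assert torch.equal(v.reshape(-1), x)
```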
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("dims", [[16, 16, 512, 16], [2, 512, 512, 128], [3, 13, 14, 16]])
def test_view_and_reshape_2D(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
) -> None:
"""Test view operations that preserve tensor shape"""
device = "cuda"
def is_bitwise_equal(a, b):
if a.numel() != b.numel():
return False
a_flat = a.reshape(-1).view(torch.uint8)
b_flat = b.reshape(-1).view(torch.uint8)
return torch.all((a_flat ^ b_flat) == 0)
x_hp = torch.rand(dims, dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=2,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test view: high-dimensional tensor -> 3D tensor
x_hp_view = x_hp.view(-1, dims[-2], dims[-1]).contiguous()
x_fp8_view = x_fp8.view(-1, dims[-2], dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
# Check the data ptr
assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
# Test reshape: high-dimensional tensor -> 3D tensor
x_hp_reshape = x_hp.reshape(-1, dims[-2], dims[-1]).contiguous()
x_fp8_reshape = x_fp8.reshape(-1, dims[-2], dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
......
......@@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
......@@ -400,7 +400,7 @@ class TestCurrentScalingFloat8Tensor:
"""Check numerical error when casting to FP8"""
# Skip invalid configurations
if non_tn_fp8_gemm_supported() and return_transpose:
if is_non_tn_fp8_gemm_supported() and return_transpose:
pytest.skip("FP8 transpose is neither needed nor supported on current system")
# Initialize random high precision data
......
......@@ -12,10 +12,11 @@ from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import gpu_autocast_ctx
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......@@ -596,7 +597,7 @@ class AdamTest:
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
......@@ -605,7 +606,7 @@ class AdamTest:
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
......@@ -647,7 +648,7 @@ class AdamTest:
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
......@@ -656,7 +657,7 @@ class AdamTest:
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
......@@ -705,7 +706,7 @@ class AdamTest:
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
......@@ -714,7 +715,7 @@ class AdamTest:
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Tuple, Union
import math
import pytest
import torch
from typing import Callable, Tuple, Union
from transformer_engine.pytorch.dot_product_attention.rope import (
import pytest
from transformer_engine.pytorch.attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
......@@ -22,6 +22,7 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return torch.sum(output * t)
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
......@@ -43,7 +44,17 @@ def test_fused_rope(
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
if margin == 0 and start_positions == True:
# This makes sure that the `start_positions` offsets being applied
# are within the maximum length of the rope embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
......@@ -51,6 +62,14 @@ def test_fused_rope(
dtype=dtype,
device=device,
)
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
if start_positions
else None
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
if transpose:
......@@ -69,14 +88,18 @@ def test_fused_rope(
t.float(),
emb,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
......@@ -84,21 +107,29 @@ def test_fused_rope(
t,
emb,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@pytest.mark.parametrize("margin", [10])
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
......@@ -114,10 +145,25 @@ def test_fused_rope_thd(
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
margin: int,
) -> None:
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
if start_positions
else None
)
if cp_size > 1:
cu_seqlens_padded = [0]
for i in range(1, len(cu_seqlens)):
......@@ -152,6 +198,7 @@ def test_fused_rope_thd(
output_unfused = apply_rotary_pos_emb(
t.float(),
emb,
start_positions=start_positions,
tensor_format="thd",
interleaved=interleaved,
fused=False,
......@@ -160,6 +207,8 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
......@@ -168,6 +217,7 @@ def test_fused_rope_thd(
output_fused = apply_rotary_pos_emb(
t,
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
tensor_format="thd",
......@@ -176,9 +226,15 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
......@@ -160,7 +160,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
normab = torch.cat((a.norm().view(1), b.norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], False)
reference = torch.full(
[(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
......
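The flipped trailing argument selects the non-per-tensor path in the `else` branch, matching the global-norm reference computed above. Assuming the trailing argument is the `per_tensor` flag (Apex `multi_tensor_applier` convention), the `per_tensor=False` result reduces to:

```python
import torch

# Global L2 norm over a list of tensors: the norm of all elements
# concatenated, which is what the per_tensor=False branch compares against.
tensors = [torch.randn(16), torch.randn(32)]
global_norm = torch.cat([t.reshape(-1) for t in tensors]).norm()
```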
......@@ -7,7 +7,6 @@ import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
import torch
......@@ -40,12 +39,12 @@ from transformer_engine.pytorch import (
Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
from transformer_engine.common import recipe
import transformer_engine_torch as tex
......@@ -135,18 +134,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
tols = dict(atol=atol)
tols = dtype_tols(t2.dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
tol = atol + (rtol * torch.abs(t2))
tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
exceed_mask = diff > tol
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
......@@ -2304,6 +2305,12 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE:
pytest.skip("KV cache does not support starting positions for RoPE")
if (
backend == "FusedAttention"
and get_device_compute_capability() == (8, 9)
and get_cudnn_version() < (9, 11, 0)
):
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
......
......@@ -19,11 +19,12 @@ class TestParallelCrossEntropy:
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
def generate_input(self, dtype: torch.dtype, swap_dim: bool):
def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
SQ = random.choice([64, 128])
batch = random.choice([1, 2])
vocab = random.choice([64000, 128000])
ignore = random.sample(range(0, SQ - 1), 5)
if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
......@@ -32,14 +33,27 @@ class TestParallelCrossEntropy:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
if ignore_idx:
for i in ignore:
# Ignore 5 indices
if swap_dim:
self.tar_test[i][0] = -100
else:
self.tar_test[0][i] = -100
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
def one_iteration_test(
self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
self,
dtype: torch.dtype,
swap_dim: bool,
label_smoothing: float,
reduce_loss: bool,
ignore_idx: bool = False,
):
self.generate_input(dtype, swap_dim)
self.generate_input(dtype, swap_dim, ignore_idx)
self.input_test.requires_grad_(True)
self.input_ref.requires_grad_(True)
......@@ -57,6 +71,8 @@ class TestParallelCrossEntropy:
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
if ignore_idx:
print(test_loss, ref_loss)
if reduce_loss:
torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
......@@ -106,3 +122,15 @@ class TestParallelCrossEntropy:
self.one_iteration_test(
dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
)
def test_ignore_idx(self):
self.generate_iters(5)
self.generate_infra(False, 0)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32,
swap_dim=random.choice([True, False]),
label_smoothing=0,
reduce_loss=False,
ignore_idx=True,
)
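The `-100` sentinel used in `generate_input` matches PyTorch's default `ignore_index`. A minimal sketch of the reference behavior the new test relies on:

```python
import torch
import torch.nn.functional as F

# Targets equal to ignore_index (-100 by default) contribute zero loss
# and receive zero gradient.
logits = torch.randn(4, 10, requires_grad=True)
targets = torch.tensor([1, -100, 3, -100])
loss = F.cross_entropy(logits, targets, reduction="none")
assert loss[1] == 0 and loss[3] == 0
```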
......@@ -373,7 +373,9 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
torch.cuda.synchronize()
def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
def _test_sanity_common(
block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
):
if skip_dgrad and skip_wgrad:
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
......@@ -389,7 +391,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
if not microbatching:
te_out = block(te_inp)
else:
_ = block(te_inp, is_first_microbatch=True)
te_out = block(te_inp, is_first_microbatch=False)
if isinstance(te_out, tuple):
te_out = te_out[0]
loss = te_out.sum()
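For context, `is_first_microbatch` is TE's hint that lets modules cache FP8 weight casts across the microbatches of one optimizer step. A hedged sketch of the usage pattern the new `microbatching` flag exercises (module and sizes are hypothetical):

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical setup: any TE module accepting is_first_microbatch works here.
block = te.Linear(128, 128)
microbatches = [torch.randn(4, 128, device="cuda") for _ in range(2)]
for i, micro_inp in enumerate(microbatches):
    # Cast/cache FP8 weights on the first microbatch, reuse them afterwards.
    out = block(micro_inp, is_first_microbatch=(i == 0))
    (out.sum() / len(microbatches)).backward()
```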
......@@ -443,8 +449,16 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
skip_dgrad,
normalization,
microbatching,
):
config = model_configs[model]
......@@ -470,7 +484,7 @@ def test_sanity_layernorm_linear(
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......@@ -478,7 +492,8 @@ def test_sanity_layernorm_linear(
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
config = model_configs[model]
if fp8_recipe is not None:
......@@ -501,7 +516,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......@@ -600,8 +615,17 @@ def test_sanity_grouped_linear(
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
skip_dgrad,
activation,
normalization,
microbatching,
):
config = model_configs[model]
......@@ -630,7 +654,7 @@ def test_sanity_layernorm_mlp(
params_dtype=dtype,
device="cuda",
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@pytest.mark.parametrize("dtype", param_types)
......
......@@ -11,12 +11,12 @@ import transformer_engine.common
try:
from . import pytorch
except (ImportError, StopIteration) as e:
except ImportError as e:
pass
try:
from . import jax
except (ImportError, StopIteration) as e:
except ImportError as e:
pass
__version__ = str(metadata.version("transformer_engine"))
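Narrowing the handler means only a genuinely missing optional framework package is silenced; any other error raised during import now propagates. The pattern in isolation (module name hypothetical):

```python
# Hedged sketch: optional-dependency guard. Only ImportError is swallowed;
# unrelated exceptions raised while importing still surface to the caller.
try:
    import some_optional_extension  # hypothetical module
except ImportError:
    some_optional_extension = None
```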
......@@ -111,6 +111,11 @@ if(USE_CUDA)
cudnn_utils.cpp
transformer_engine.cpp
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
......@@ -148,6 +153,7 @@ if(USE_CUDA)
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
......@@ -158,6 +164,11 @@ else()
cudnn_utils.cpp
transformer_engine.cpp
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
......@@ -191,6 +202,7 @@ else()
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
......@@ -345,6 +357,14 @@ target_include_directories(transformer_engine PRIVATE
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
......
......@@ -9,28 +9,193 @@ import glob
import sysconfig
import subprocess
import ctypes
import logging
import os
import platform
import importlib
import functools
from pathlib import Path
from importlib.metadata import version, metadata, PackageNotFoundError
import transformer_engine
_logger = logging.getLogger(__name__)
def is_package_installed(package):
"""Checks if a pip package is installed."""
return (
subprocess.run(
[sys.executable, "-m", "pip", "show", package], capture_output=True, check=False
).returncode
== 0
)
@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package):
"""Check if the given package is installed via pip."""
# This is needed because we only want to return true
# if the python package is installed via pip, and not
# if it's importable in the current directory due to
# the presence of the shared library module.
try:
metadata(package)
except PackageNotFoundError:
return False
return True
@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
"""
Find a shared object file of given prefix in the top level TE directory.
Only the following locations are searched to avoid stray SOs and build
artifacts:
1. The given top level directory (editable install).
2. `transformer_engine` named directories (source install).
3. `wheel_lib` named directories (PyPI install).
Returns None if no shared object files are found.
Raises an error if multiple shared object files are found.
"""
# Ensure the top-level dir exists and has the module before searching.
if not te_path.exists() or not (te_path / "transformer_engine").exists():
return None
files = []
search_paths = (
te_path,
te_path / "transformer_engine",
te_path / "transformer_engine/wheel_lib",
te_path / "wheel_lib",
)
# Search.
for dirname, _, names in os.walk(te_path):
if Path(dirname) in search_paths:
for name in names:
if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"):
files.append(Path(dirname, name))
if len(files) == 0:
return None
if len(files) == 1:
return files[0]
raise RuntimeError(f"Multiple files found: {files}")
@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
"""
Return the path of the shared object file for the given TE
library, one of 'core', 'torch', or 'jax'.
Several factors affect finding the correct location of the shared object:
1. System and environment.
2. If the installation is from source or via PyPI.
- Source installed .sos are placed in top level dir
- Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
3. For source installations, is the install editable/inplace?
4. The user directory from where TE is being imported.
"""
# Check provided input and determine the correct prefix for .so.
assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}."
if library == "core":
so_prefix = "libtransformer_engine"
else:
so_prefix = f"transformer_engine_{library}"
# Check TE install location (will be local if TE is available in current dir for import).
te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)
# Check default python package install location in system.
site_packages_dir = Path(sysconfig.get_paths()["purelib"])
so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)
# Case 1: Typical user workflow: Both locations are the same, return any result.
if te_install_dir == site_packages_dir:
assert (
so_path_in_install_dir is not None
), f"Could not find shared object file for Transformer Engine {library} lib."
return so_path_in_install_dir
# Case 2: ERR! Both locations are different but returned a valid result.
# NOTE: Unlike for source installations, pip does not wipe out artifacts from
# editable builds. In case developers are executing inside a TE directory via
# an inplace build, and then move to a regular build, the local shared object
# file will be incorrectly picked up without the following logic.
if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
raise RuntimeError(
f"Found multiple shared object files: {so_path_in_install_dir} and"
f" {so_path_in_default_dir}. Remove local shared objects installed"
f" here {so_path_in_install_dir} or change the working directory to"
"execute from outside TE."
)
# Case 3: Typical dev workflow: Editable install
if so_path_in_install_dir is not None:
return so_path_in_install_dir
# Case 4: Executing from inside a TE directory without an inplace build available.
if so_path_in_default_dir is not None:
return so_path_in_default_dir
raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str):
"""
Load shared library with Transformer Engine framework bindings
and verify correctness if installed via PyPI.
"""
# Supported frameworks.
assert framework in ("jax", "torch"), f"Unsupported framework {framework}"
# Name of the framework extension library.
module_name = f"transformer_engine_{framework}"
# Name of the pip extra dependency for framework extensions from PyPI.
extra_dep_name = module_name
if framework == "torch":
extra_dep_name = "pytorch"
# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
if _is_pip_package_installed(module_name):
assert _is_pip_package_installed(
"transformer_engine"
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)
# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed("transformer-engine-cu12"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
)
def get_te_path():
"""Find Transformer Engine install path using pip"""
return Path(transformer_engine.__path__[0]).parent
# After all checks are completed, load the shared object file.
spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
@functools.lru_cache(maxsize=None)
def _get_sys_extension():
system = platform.system()
if system == "Linux":
......@@ -45,20 +210,47 @@ def _get_sys_extension():
return extension
def _load_cudnn():
"""Load CUDNN shared library."""
# Attempt to locate cuDNN in Python dist-packages
lib_path = glob.glob(
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
"""
Attempts to load shared object file installed via pip.
`lib_name`: Name of package as found in the `nvidia` dir in python environment.
"""
so_paths = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
)
)
if lib_path:
assert (
len(lib_path) == 1
), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
path_found = len(so_paths) > 0
ctypes_handles = []
if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir():
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
try:
import nvidia
except ModuleNotFoundError:
return ""
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
return str(include_dir) if include_dir.exists() else ""
@functools.lru_cache(maxsize=None)
def _load_cudnn():
"""Load CUDNN shared library."""
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
......@@ -75,28 +267,16 @@ def _load_cudnn():
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in Python dist-packages
found, handle = _load_nvidia_cuda_library("cudnn")
if found:
return handle
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
if not so_path.exists():
so_path = (
get_te_path()
/ "transformer_engine"
/ "wheel_lib"
/ f"libtransformer_engine.{_get_sys_extension()}"
)
if not so_path.exists():
so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
......@@ -107,6 +287,11 @@ def _load_nvrtc():
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC in Python dist-packages
found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
if found:
return handle
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n")
......@@ -123,10 +308,22 @@ def _load_nvrtc():
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
@functools.lru_cache(maxsize=None)
def _load_core_library():
"""Load shared library with Transformer Engine C extensions"""
return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_GLOBAL)
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
try:
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
except OSError:
pass
_TE_LIB_CTYPES = _load_library()
_TE_LIB_CTYPES = _load_core_library()
......@@ -21,12 +21,18 @@
#define HALF_BYTES 2
#define UB_MAX_SM 32
#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)
using namespace std::placeholders;
namespace transformer_engine {
namespace {
std::vector<size_t> shape_to_vector(const NVTEShape &shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
} // namespace
/***************************************************************************************************
* Comm+GEMM Overlap Common Core
**************************************************************************************************/
......@@ -147,13 +153,50 @@ CommOverlapCore::~CommOverlapCore() {
TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
const std::vector<size_t> &chunk_shape) {
TensorWrapper chunk;
const auto scaling_mode = source.scaling_mode();
// Tensor dimensions
std::vector<size_t> shape = shape_to_vector(source.shape());
auto flatten_shape_to_2d = [](const std::vector<size_t> &shape) -> std::pair<size_t, size_t> {
if (shape.empty()) {
return {1, 1};
}
size_t height = 1;
for (size_t i = 0; i < shape.size() - 1; ++i) {
height *= shape[i];
}
return {height, shape.back()};
};
size_t height, width, chunk_height, chunk_width;
std::tie(height, width) = flatten_shape_to_2d(shape);
std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape);
// Check tensor dimensions
#define NVTE_DIM_CHECK(cond, message) \
NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
", chunk offset=", chunk_offset, ")")
NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor");
NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
"Attempted to get out-of-bounds tensor chunk");
if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// MXFP8 scale-inverses are padded to a 2D matrix with dims that
// are divisible by 128. UB doesn't handle this padding yet.
NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0,
"Userbuffers requires MXFP8 tensor dims that are divisible by 128");
NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
"Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
}
#undef NVTE_DIM_CHECK
// Construct tensor chunk
TensorWrapper chunk(scaling_mode);
for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
auto param_type = static_cast<NVTETensorParam>(param_id);
auto param = source.get_parameter(param_type);
auto param_dptr = reinterpret_cast<char *>(param.data_ptr);
auto param_dtype = static_cast<DType>(param.dtype);
auto param_shape = AS_VECTOR(param.shape);
auto param_shape = shape_to_vector(param.shape);
if (param_dptr != nullptr) {
if (param_type == NVTETensorParam::kNVTERowwiseData ||
......@@ -163,8 +206,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// Columnwise shape for non-block scaled tensors shifts the last dimension to the front
source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
// Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front
auto last_dim = param_shape.back();
param_shape.pop_back();
param_shape.insert(param_shape.begin(), last_dim);
......@@ -172,18 +215,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
} else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
(param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
// Calculate block scaling offset and size
auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? source.shape().data[0]
: source.columnwise_shape().data[0];
auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? chunk_shape.front()
: chunk_shape.back();
auto chunk_scale_start = chunk_offset / 32;
auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
param_dptr += chunk_scale_start * typeToSize(param_dtype);
param_shape = std::vector<size_t>{chunk_scale_size};
// Calculate offset and size for MXFP8 scale-invs
size_t chunk_scale_height = chunk_height;
size_t chunk_scale_width = chunk_width;
if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) {
chunk_scale_width /= 32;
} else {
chunk_scale_height /= 32;
}
param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
param_shape = {chunk_scale_height, chunk_scale_width};
}
// Set chunked source parameters into the chunked tensor output
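A small shapes-only check of the MXFP8 scale-inverse chunking above: with 1D block scaling there is one scale per 32-element block along the scaled dimension, so rowwise scales shrink the width and columnwise scales shrink the height (a Python sketch mirroring the C++ arithmetic, padding ignored):

```python
# Rowwise: one scale per 32 elements along a row -> (height, width // 32).
# Columnwise: one scale per 32 elements per column -> (height // 32, width).
height, width = 256, 512
rowwise_scale_shape = (height, width // 32)
columnwise_scale_shape = (height // 32, width)
```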
......@@ -434,10 +475,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
size_t k = transa ? A.size(1) : A.size(0);
size_t n = _ubuf.size(0);
size_t m_chunk = m / _num_splits;
const std::vector<size_t> input_a_chunk_shape =
(transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
const std::vector<size_t> output_chunk_shape = {n, m_chunk};
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Helper function to get bias chunk if needed
auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper {
if (bias.dptr() == nullptr) {
return TensorWrapper();
}
return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk});
};
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
......@@ -449,12 +501,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_rs_overlap_first_gemm) {
auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape);
auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(0);
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
} else {
......@@ -464,18 +517,19 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
bias_chunk = maybe_get_bias_chunk(i);
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
......@@ -519,13 +573,14 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
} else {
for (int i = 0; i < _num_splits; i++) {
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(i);
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
......@@ -605,14 +660,17 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
_ubuf = TensorWrapper(
buffer_ptr,
std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
{buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype));
std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
......@@ -661,7 +719,7 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape()));
auto chunk = get_tensor_chunk(source, 0, shape_to_vector(_ubufs[chunk_id].shape()));
// Update chunk with offset data pointers from the communication buffer
if (chunk.dptr() != nullptr) {
......@@ -711,7 +769,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape()));
auto input_b = get_buffer_chunk_like(B, 0, shape_to_vector(B.shape()));
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
......@@ -798,8 +856,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
......@@ -810,10 +866,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
if (_aggregate) {
const int num_steps = _tp_size / 2;
#ifndef __HIP_PLATFORM_AMD__
input_chunk_size *= 2;
output_chunk_size *= 2;
#endif
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
std::vector<size_t> output_chunk_shape = {2 * n_chunk, m};
size_t input_b_chunk_size = 2 * n_chunk * k;
size_t output_chunk_size = 2 * n_chunk * m;
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
......@@ -842,8 +901,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// GEMM
auto input_b_chunk =
get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
......@@ -882,6 +942,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
}
}
} else {
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
std::vector<size_t> output_chunk_shape = {n_chunk, m};
size_t input_b_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubuf[rank]. This is to
......@@ -893,8 +960,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
auto input_b_chunk =
get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
......@@ -972,7 +1041,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape()));
auto output_d = get_buffer_chunk_like(D, 0, shape_to_vector(D.shape()));
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace.data(), accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, true, _counter.data(), stream_main);
......@@ -1053,6 +1122,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
auto output_chunk = get_buffer_chunk_by_id(D, i);
auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
......
......@@ -35,6 +35,65 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
}
}
namespace {
constexpr size_t kThreadsPerBlock = 256;
template <typename TVectorized>
__global__ void __launch_bounds__(kThreadsPerBlock)
memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx * sizeof(TVectorized) >= size_in_bytes) {
return; // Out of bounds
}
if ((idx + 1) * sizeof(TVectorized) > size_in_bytes) {
// If the buffer size is not an even multiple of the vectorization, manually set the remaining bytes unvectorized.
size_t remaining_bytes = size_in_bytes - idx * sizeof(TVectorized);
memset(reinterpret_cast<uint8_t *>(ptr) + idx * sizeof(TVectorized), value, remaining_bytes);
return;
}
union {
TVectorized value;
uint8_t data[sizeof(TVectorized)];
} data;
for (size_t i = 0; i < sizeof(TVectorized); ++i) {
data.data[i] = static_cast<uint8_t>(value);
}
reinterpret_cast<TVectorized *>(ptr)[idx] = data.value;
}
} // namespace
#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
if (size_in_bytes >= sizeof(vectorizedType) && \
reinterpret_cast<size_t>(ptr) % sizeof(vectorizedType) == 0) { \
size_t numBlocks = DIVUP(size_in_bytes, kThreadsPerBlock * sizeof(vectorizedType)); \
dim3 grid(numBlocks, 1, 1); \
memset_kernel<vectorizedType> \
<<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes); \
return; \
}
extern "C" {
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
NVTE_API_CALL(nvte_memset);
NVTE_CHECK(ptr != nullptr, "Pointer for memset must be allocated.");
if (size_in_bytes > 4096) {
// Use cudaMemsetAsync for larger sizes.
cudaMemsetAsync(ptr, value, size_in_bytes, stream);
return;
}
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float4, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float2, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
}
} // extern "C"
void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
return;
......@@ -144,4 +203,16 @@ bool is_supported_by_CC_100() {
#endif
}
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size) {
std::vector<std::vector<Tensor *>> ret;
for (size_t i = 0; i < outer_size; ++i) {
ret.emplace_back();
for (size_t j = 0; j < inner_size; ++j) {
ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
}
}
return ret;
}
} // namespace transformer_engine
......@@ -116,7 +116,7 @@ struct Tensor {
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
int numel() const {
size_t numel() const {
size_t acc = 1;
for (const auto dim : shape()) {
acc *= dim;
......@@ -138,6 +138,14 @@ struct Tensor {
return data.dtype;
}
size_t dim() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size();
} else {
return data.shape.size();
}
}
std::vector<size_t> shape() const {
/* Note: We sometimes experience spurious compiler errors
* (-Wstringop-overflow) from this function. It appears that GCC
......@@ -243,6 +251,7 @@ constexpr T DIVUP(const T &x, const T &y) {
}
using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
......@@ -271,6 +280,7 @@ constexpr inline const char *type_name() noexcept;
return #T; \
}
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
......@@ -327,7 +337,7 @@ struct TypeExtrema {
template <typename T>
struct TypeInfo {
using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
template <typename U, DType current>
struct Helper {
......@@ -364,6 +374,10 @@ struct TypeInfo {
using type = unsigned char; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt16: { \
using type = int16_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt32: { \
using type = int32_t; \
{ __VA_ARGS__ } \
......@@ -400,6 +414,33 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......@@ -599,6 +640,9 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
bool is_supported_by_CC_100();
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
......@@ -3,21 +3,25 @@
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
#include <assert.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace context_parallel {
struct LseCorrectionFunctor {
__forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx,
__forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
size_t half_idx) {
double val = lse[idx];
float val = lse[idx];
float val_per_step = half_lse[half_idx];
double max_scale = max(val, val_per_step);
double min_scale = min(val, val_per_step);
lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
float max_scale = max(val, val_per_step);
float min_scale = min(val, val_per_step);
lse[idx] = max_scale + log1pf(expf(min_scale - max_scale));
}
};
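For reference, the functor now performs the whole merge in float32; the identity it implements is the numerically stable combination of two log-sum-exp values, sketched here in Python:

```python
import math

# log(exp(a) + exp(b)) computed stably:
#   max(a, b) + log1p(exp(min(a, b) - max(a, b)))
def merge_lse(a: float, b: float) -> float:
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))
```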
......@@ -49,16 +53,13 @@ struct AddFunctor {
#pragma unroll
for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
p_[i] += p[i];
p_[i] = p_[i] + p[i];
}
reinterpret_cast<float4 *>(token)[idx] = d_;
}
};
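// Note on the functor above: the float4 round-trip is a 128-bit vectorized
// read-modify-write. Each thread moves 16 bytes per access and accumulates
// element-wise in dtype, i.e. sizeof(float4) / sizeof(dtype) = 8 lanes for
// fp16/bf16.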
namespace transformer_engine {
namespace fused_attn {
/***************************************************************************************************
* Support THD format for Context Parallel: Binary search an array for a target value
**************************************************************************************************/
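// A sketch of the search this section relies on (illustrative, assuming a
// sorted cu_seqlens array of batch + 1 prefix sums): find the sequence that
// owns a given token index by locating the last entry <= token.
//
//   __forceinline__ __device__ int owning_seq(int token, const int *cu_seqlens,
//                                             int batch) {
//     int lo = 0, hi = batch - 1;
//     while (lo < hi) {
//       int mid = (lo + hi + 1) / 2;
//       if (cu_seqlens[mid] <= token) lo = mid;
//       else hi = mid - 1;
//     }
//     return lo;  // cu_seqlens[lo] <= token < cu_seqlens[lo + 1]
//   }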
......@@ -107,6 +108,7 @@ __global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int b
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token) {
......@@ -148,8 +150,8 @@ __global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_se
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
template <bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(float *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int lse_seqlen, int second_half_lse_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
......@@ -218,7 +220,7 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * lse_per_step_seqlen + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
float lse_corrected_exp = expf(lse_per_step[idx_per_step] - lse[idx]);
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx = (idx * num_heads + head_id) * dim_per_head;
......@@ -232,7 +234,10 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
dtype *p = reinterpret_cast<dtype *>(&data);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
p[k] = p[k] +
(p_per_step[k] == static_cast<dtype>(0.f)
? static_cast<dtype>(0.f)
: static_cast<dtype>(static_cast<float>(p_per_step[k]) * lse_corrected_exp));
}
reinterpret_cast<float4 *>(cur_out)[j] = data;
}
......@@ -297,6 +302,442 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in
}
}
} // namespace fused_attn
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor &half,
int half_idx, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
auto cu_seqlens_shape = cu_seqlens.shape();
auto tensor_shape = tensor.shape();
NVTE_CHECK(cu_seqlens.dim() == 1);
NVTE_CHECK(cu_seqlens_shape[0] >= 2);
// Shapes of q and dq are [t, h, d], so the dimension of "t" is 0
// Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
int seq_dim = tensor.dim() == 3 ? 0 : 1;
int batch = cu_seqlens_shape[0] - 1;
int num_heads = tensor_shape[seq_dim + 1];
int dim_per_head = tensor_shape[seq_dim + 2];
int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype());
// For 128-bits load/store
NVTE_CHECK(hidden_size_in_bytes % 16 == 0);
// Launch Kernel
constexpr unsigned int block = 256;
unsigned int grid_x = (tensor_shape[seq_dim] / 2 * 32 + block - 1) / block;
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= tensor_shape[i];
}
dim3 grid = {grid_x, grid_y};
thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
}
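// Hedged call-site sketch (tensor handles are assumed): copy the second half
// of every sequence of a THD tensor q with shape [t, h, d] into half_q with
// shape [t / 2, h, d].
//
//   thd_read_half_tensor(q, cu_seqlens, half_q, /*half_idx=*/1, stream);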
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
const Tensor &cu_seqlens, bool lse_packed,
cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, lse_seqlen, second_half_lse_seqlen;
auto cu_seqlens_shape = cu_seqlens.shape();
auto lse_shape = lse.shape();
auto lse_per_step_shape = lse_per_step.shape();
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
NVTE_CHECK(lse_per_step.dim() == 2);
batch = cu_seqlens_shape[0] - 1;
num_heads = lse_shape[0];
lse_seqlen = lse_shape[1];
second_half_lse_seqlen = lse_per_step_shape[1];
NVTE_CHECK(lse_per_step_shape[0] == num_heads);
NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(lse_per_step.dim() == 3);
batch = lse_shape[0];
num_heads = lse_shape[1];
lse_seqlen = lse_shape[2];
second_half_lse_seqlen = lse_per_step_shape[2];
NVTE_CHECK(lse_per_step_shape[0] == batch);
NVTE_CHECK(lse_per_step_shape[1] == num_heads);
NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
thd_lse_kernel<true, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
} else {
thd_lse_kernel<false, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
}
}
void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tensor &half_lse,
bool lse_packed, int second_half_lse_seqlen, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, lse_seqlen;
auto cu_seqlens_shape = cu_seqlens.shape();
auto lse_shape = lse.shape();
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
batch = cu_seqlens_shape[0] - 1;
num_heads = lse_shape[0];
lse_seqlen = lse_shape[1];
NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
batch = lse_shape[0];
num_heads = lse_shape[1];
lse_seqlen = lse_shape[2];
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
thd_lse_kernel<true, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
} else {
thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
template <typename dtype, int only_second_half>
static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, const Tensor &lse,
const Tensor &lse_per_step, const Tensor &cu_seqlens,
bool lse_packed, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(out.dtype() == out_per_step.dtype());
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
auto out_shape = out.shape();
auto lse_shape = lse.shape();
auto out_per_step_shape = out_per_step.shape();
auto lse_per_step_shape = lse_per_step.shape();
auto cu_seqlens_shape = cu_seqlens.shape();
int total_tokens = out_shape[0];
int num_heads = out_shape[1];
int dim_per_head = out_shape[2];
NVTE_CHECK(out_per_step_shape[0] == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step_shape[1] == num_heads);
NVTE_CHECK(out_per_step_shape[2] == dim_per_head);
int batch, lse_seqlen, lse_per_step_seqlen;
if (lse_packed) {
batch = cu_seqlens_shape[0] - 1;
lse_seqlen = lse_shape[1];
lse_per_step_seqlen = lse_per_step_shape[1];
NVTE_CHECK(lse_shape[0] == num_heads);
NVTE_CHECK(lse_seqlen >= total_tokens);
NVTE_CHECK(lse_per_step_shape[0] == num_heads);
NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1));
} else {
batch = lse_shape[0];
lse_seqlen = lse_shape[2];
lse_per_step_seqlen = lse_per_step_shape[2];
NVTE_CHECK(lse_shape[1] == num_heads);
NVTE_CHECK(lse_per_step_shape[0] == batch);
NVTE_CHECK(lse_per_step_shape[1] == num_heads);
NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
}
constexpr int tile = 16;
constexpr int block = 512;
unsigned int grid_x =
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads};
if (lse_packed) {
thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(out.data.dptr),
reinterpret_cast<dtype *>(out_per_step.data.dptr),
reinterpret_cast<float *>(lse.data.dptr),
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(out.data.dptr),
reinterpret_cast<dtype *>(out_per_step.data.dptr),
reinterpret_cast<float *>(lse.data.dptr),
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
}
}
void thd_out_correction(Tensor out, const Tensor &out_per_step, const Tensor &lse,
const Tensor &lse_per_step, const Tensor &cu_seqlens, bool only_second_half,
bool lse_packed, cudaStream_t stream) {
using namespace transformer_engine;
if (only_second_half) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
out.dtype(), dtype,
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed, stream););
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
out.dtype(), dtype,
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed, stream););
}
}
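// Per-element form of the correction the kernel applies, as a host-side
// sketch with assumed float accumulation: each per-step partial output is
// rescaled by exp(lse_step - lse_final) before being accumulated.
//
//   #include <cmath>
//   float correct_out(float out, float out_step, float lse, float lse_step) {
//     return out + out_step * std::exp(lse_step - lse);
//   }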
/***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
const Tensor &cu_seqlens, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
auto grad_shape = grad.shape();
auto cu_seqlens_shape = cu_seqlens.shape();
auto grad_per_step_shape = grad_per_step.shape();
// Shape of dq is [t, h, d], so the dimension of "t" is 0
// Shape of dkv is [2, t, h, d], so the dimension of "t" is 1
int seq_dim = grad.dim() == 3 ? 0 : 1;
int total_tokens = grad_shape[seq_dim];
int num_heads = grad_shape[seq_dim + 1];
int dim_per_head = grad_shape[seq_dim + 2];
int batch = cu_seqlens_shape[0] - 1;
if constexpr (functor_idx < 2) {
NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens / 2);
} else {
NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens);
}
NVTE_CHECK(grad_per_step_shape[seq_dim + 1] == num_heads);
NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);
size_t hidden_size = num_heads * dim_per_head;
NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0);
constexpr unsigned int block = 256;
unsigned int grid_x;
if constexpr (functor_idx < 2) {
grid_x = (total_tokens / 2 * 32 + block - 1) / block;
} else {
grid_x = (total_tokens * 32 + block - 1) / block;
}
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= grad_shape[i];
}
dim3 grid = {grid_x, grid_y};
thd_grad_correction_kernel<dtype, Functor_0, Functor_1, functor_idx, 32>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(grad.data.dptr),
reinterpret_cast<dtype *>(grad_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
}
template <typename dtype>
static void thd_grad_dispatcher(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
const std::string &first_half, const std::string &second_half,
cudaStream_t stream) {
using namespace transformer_engine;
if (first_half == "add" && second_half == "none") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, EmptyFunctor, 0>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "copy" && second_half == "none") {
thd_grad_correction_helper<dtype, CopyFunctor, EmptyFunctor, 0>(grad, grad_per_step, cu_seqlens,
stream);
} else if (first_half == "none" && second_half == "add") {
thd_grad_correction_helper<dtype, EmptyFunctor, AddFunctor<dtype>, 1>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "none" && second_half == "copy") {
thd_grad_correction_helper<dtype, EmptyFunctor, CopyFunctor, 1>(grad, grad_per_step, cu_seqlens,
stream);
} else if (first_half == "add" && second_half == "copy") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, CopyFunctor, 2>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "copy" && second_half == "add") {
thd_grad_correction_helper<dtype, CopyFunctor, AddFunctor<dtype>, 2>(grad, grad_per_step,
cu_seqlens, stream);
} else {
NVTE_ERROR("Unsupported Functor of first half and second_half\n");
}
}
void thd_grad_correction(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
const std::string &first_half, const std::string &second_half,
cudaStream_t stream) {
using namespace transformer_engine;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
grad.dtype(), dtype,
thd_grad_dispatcher<dtype>(grad, grad_per_step, cu_seqlens, first_half, second_half,
stream););
}
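// Hedged usage sketch (tensor names assumed): accumulate into the first half
// of dq while leaving its second half untouched, then accumulate into the
// first half of dkv while overwriting its second half.
//
//   thd_grad_correction(dq, dq_per_step, cu_seqlens, "add", "none", stream);
//   thd_grad_correction(dkv, dkv_per_step, cu_seqlens, "add", "copy", stream);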
/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/
void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int total_tokens,
int world_size, int rank, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
auto cu_seqlens_shape = cu_seqlens.shape();
auto output_shape = output.shape();
NVTE_CHECK(cu_seqlens_shape[0] >= 2);
NVTE_CHECK(rank >= 0 && rank < world_size);
NVTE_CHECK(world_size > 0);
NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
int batch = cu_seqlens_shape[0] - 1;
constexpr unsigned int block = 256;
unsigned int grid = (output_shape[0] + block - 1) / block;
thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
batch, total_tokens, world_size, rank);
}
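// Worked example of the partitioning contract (chunk assignment follows the
// usual context-parallel load-balancing layout; illustrative only): with
// world_size = 4 and total_tokens = 16, the tokens form 2 * world_size = 8
// chunks of 2, and rank r receives chunks r and 2 * world_size - 1 - r, so
// rank 0 gets tokens {0, 1, 14, 15}.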
} // namespace context_parallel
} // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream) {
NVTE_API_CALL(nvte_cp_thd_read_half_tensor);
using namespace transformer_engine;
context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(half), half_idx, stream);
}
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cp_thd_second_half_lse_correction);
using namespace transformer_engine;
context_parallel::thd_second_half_lse_correction(
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), lse_packed, stream);
}
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream) {
NVTE_API_CALL(nvte_cp_thd_read_second_half_lse);
using namespace transformer_engine;
context_parallel::thd_read_second_half_lse(
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(half_lse), lse_packed, second_half_lse_seqlen, stream);
}
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cp_thd_out_correction);
using namespace transformer_engine;
context_parallel::thd_out_correction(
*reinterpret_cast<Tensor *>(out), *reinterpret_cast<Tensor *>(out_per_step),
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), only_second_half, lse_packed, stream);
}
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream) {
NVTE_API_CALL(nvte_cp_thd_grad_correction);
using namespace transformer_engine;
std::string first_half_str(first_half);
std::string second_half_str(second_half);
context_parallel::thd_grad_correction(
*reinterpret_cast<Tensor *>(grad), *reinterpret_cast<Tensor *>(grad_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), first_half_str, second_half_str, stream);
}
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cp_thd_get_partitioned_indices);
using namespace transformer_engine;
context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(output), total_tokens,
world_size, rank, stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
const size_t S, const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
}
}
void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.dtype() == DType::kFloat16 || qkvi.dtype() == DType::kBFloat16);
auto qkvi_shape = qkvi.shape();
NVTE_CHECK(qkvi_shape[3] % load_size == 0);
NVTE_CHECK(qkvi_shape[3] == load_size);
// [s, b, n, h * 3] -> [3, b, s, n, h]
std::vector<uint64_t> shape = {3, qkvi_shape[1], qkvi_shape[0], qkvi_shape[2], qkvi_shape[3]};
size_t warps = qkvi_shape[0] * qkvi_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
qkvi.dtype(), dtype,
prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
shape[1], shape[2], shape[3], shape[4]););
}
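// Hedged shape sketch for the forward prepare (handles assumed): qkvi is the
// interleaved [s, b, n, 3 * h] projection output and qkv is a contiguous
// [3, b, s, n, h] buffer whose slices along dim 0 become q, k and v.
//
//   prepare_flash_attn_fwd(qkvi, qkv, stream);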
void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(q.dtype() == DType::kFloat16 || q.dtype() == DType::kBFloat16);
NVTE_CHECK(k.dtype() == q.dtype());
NVTE_CHECK(v.dtype() == q.dtype());
auto q_shape = q.shape();
auto k_shape = k.shape();
auto v_shape = v.shape();
NVTE_CHECK(q_shape[3] % load_size == 0);
NVTE_CHECK(q_shape[3] == load_size);
NVTE_CHECK(k_shape[3] % load_size == 0);
NVTE_CHECK(k_shape[3] == load_size);
NVTE_CHECK(v_shape[3] % load_size == 0);
NVTE_CHECK(v_shape[3] == load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<uint64_t> shape = {q_shape[1], q_shape[0], q_shape[2], 3 * q_shape[3]};
size_t warps = q_shape[0] * q_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
q.dtype(), dtype,
prepare_kernel_bwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
}
} // namespace flash_attention
} // namespace transformer_engine
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
*reinterpret_cast<Tensor *>(qkv), stream);
}
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_bwd(
*reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
*reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
}
......@@ -1006,3 +1006,18 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream) {
NVTE_API_CALL(nvte_get_runtime_num_segments);
using namespace transformer_engine::fused_attn;
return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream);
}
void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
size_t q_max_seqlen, size_t kv_max_seqlen,
NVTE_Fused_Attn_Backend backend, cudaStream_t stream) {
NVTE_API_CALL(nvte_populate_rng_state_async);
using namespace transformer_engine::fused_attn;
PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
}
......@@ -3,48 +3,15 @@
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
namespace transformer_engine {
namespace fused_attn {
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
#include "../common.h"
#include "transformer_engine/fused_attn.h"
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
namespace transformer_engine {
namespace kv_cache {
template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
template <typename dtype>
__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
int d_v, int b, int max_seq_len) {
// k_cache, v_cache: bshd
......@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
}
}
template <typename scalar_t>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
scalar_t *v_cache, int *page_table, int *cu_new_lens,
int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
template <typename dtype>
__global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
int *page_table, int *cu_new_lens, int *cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v,
int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
// new_k, new_v: qkv_format; k_cache, v_cache: bshd
// cu_new_lens, cu_cached_lens: [b + 1]
......@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
}
}
}
} // namespace fused_attn
template <typename dtype>
void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
}
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache, Tensor page_table,
Tensor cu_new_lens, Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged,
cudaStream_t stream) {
int h_kv = new_k.shape()[new_k.dim() - 2];
int d_k = new_k.shape()[new_k.dim() - 1];
int d_v = new_v.shape()[new_v.dim() - 1];
NVTE_CHECK(k_cache.dtype() == v_cache.dtype() && new_k.dtype() == new_v.dtype() &&
new_k.dtype() == k_cache.dtype(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
k_cache.dtype(), dtype,
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged, stream););
}
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
template <typename scalar_t>
void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_thd_to_bshd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
int max_seq_len, cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
new_tensor.dtype(), dtype,
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len,
tensor_shape[1], tensor_shape[2], stream););
}
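// Worked layout example (illustrative): with cu_seqlens = {0, 2, 5}, b = 2
// and max_seq_len = 4, the packed THD tokens [a0 a1 b0 b1 b2] scatter to
// bshd[0][0:2] and bshd[1][0:3]; positions past each seqlen are left as-is.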
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_bshd_to_thd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
tensor.dtype(), dtype,
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, tensor_shape[0],
tensor_shape[1], tensor_shape[2], tensor_shape[3],
stream););
}
} // namespace kv_cache
} // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both the non-paged and
* paged cases; set is_non_paged accordingly.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
* become [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
* batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3]. batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only supports the same page_table for k_cache and v_cache
* 6. Only supports pad_between_seqs = False when qkv_format = thd, i.e. there should be no
* pad tokens between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream) {
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
kv_cache::copy_to_kv_cache(
*reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
*reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
*reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
*reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
max_pages_per_seq, is_non_paged, stream);
}
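// Hedged call sketch for the non-paged case (handles assumed): page_table is
// a [b, 1] table of batch indices with max_pages_per_seq = 1, so each cache
// row holds one sequence of up to max_seq_len tokens.
//
//   nvte_copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table,
//                         cu_new_lens, cu_cached_lens,
//                         NVTE_QKV_Format::NVTE_BSHD, b, max_ctx_len,
//                         max_seq_len, /*max_pages_per_seq=*/1,
//                         /*is_non_paged=*/1, stream);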
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), t, stream);
}