Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
@@ -11,7 +11,7 @@ from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    get_cudnn_version,
)
-from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
+from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
@@ -11,6 +11,12 @@ import math
import pytest
import torch

+from test_fused_attn import (
+    ModelConfig,
+    reset_rng_states,
+    _get_attention_backends,
+)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
@@ -18,20 +24,15 @@ from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
    TransformerLayer,
)
-from transformer_engine.pytorch.attention import DotProductAttention
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
-from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
+from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
+    FlashAttentionUtils as fa_utils,
+)
from transformer_engine.pytorch.utils import (
-    get_device_compute_capability,
    init_method_normal,
    scaled_init_method_normal,
    is_bf16_compatible,
)
-from test_fused_attn import (
-    ModelConfig,
-    reset_rng_states,
-    _get_attention_backends,
-)

# Initialize RNG state
seed = 1234
...
@@ -392,6 +392,110 @@ class TestFloat8BlockwiseTensor:
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize(
"dims", [[16, 16, 512], [16, 16, 512, 16], [12, 7, 11], [13, 14, 16], [2, 3, 5]]
)
def test_view_and_reshape_1D(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
) -> None:
"""Test view operations that preserve tensor shape"""
device = "cuda"
def is_bitwise_equal(a, b):
if a.numel() != b.numel():
return False
a_flat = a.reshape(-1).view(torch.uint8)
b_flat = b.reshape(-1).view(torch.uint8)
return torch.all((a_flat ^ b_flat) == 0)
x_hp = torch.rand(dims, dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=1,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test view, high dimension tensor -> 2D tensor
x_hp_view = x_hp.view(-1, dims[-1]).contiguous()
x_fp8_view = x_fp8.view(-1, dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
# Check the data ptr
assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
# Test reshape high dimension tensor -> 2D tensor
x_hp_reshape = x_hp.reshape(-1, dims[-1]).contiguous()
x_fp8_reshape = x_fp8.reshape(-1, dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("dims", [[16, 16, 512, 16], [2, 512, 512, 128], [3, 13, 14, 16]])
def test_view_and_reshape_2D(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
) -> None:
"""Test view operations that preserve tensor shape"""
device = "cuda"
def is_bitwise_equal(a, b):
if a.numel() != b.numel():
return False
a_flat = a.reshape(-1).view(torch.uint8)
b_flat = b.reshape(-1).view(torch.uint8)
return torch.all((a_flat ^ b_flat) == 0)
x_hp = torch.rand(dims, dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=2,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test view, high dimension tensor -> 2D tensor
x_hp_view = x_hp.view(-1, dims[-2], dims[-1]).contiguous()
x_fp8_view = x_fp8.view(-1, dims[-2], dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
# Check the data ptr
assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
# Test reshape high dimension tensor -> 2D tensor
x_hp_reshape = x_hp.reshape(-1, dims[-2], dims[-1]).contiguous()
x_fp8_reshape = x_fp8.reshape(-1, dims[-2], dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
......
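For reference, a minimal sketch of the quantize -> view -> dequantize round trip these new tests exercise. It uses only the calls that appear in the tests above (make_empty, update_quantized, dequantize, and the private _rowwise_data attribute); the Float8BlockQuantizer import path is an assumption and is not shown in this diff.

# Minimal usage sketch (not part of the commit); import path is assumed.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

x_hp = torch.rand(16, 16, 512, dtype=torch.bfloat16, device="cuda")
quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True, block_scaling_dim=1
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=x_hp.dtype, device="cuda")
quantizer.update_quantized(x_hp.clone(), x_fp8)

# A view to 2D shares the quantized payload with the original tensor ...
x_view = x_fp8.view(-1, 512)
assert x_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
# ... and dequantizes to (approximately) the same high-precision values.
print((x_view.dequantize() - x_hp.view(-1, 512)).abs().max())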
@@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
-from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
+from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
@@ -400,7 +400,7 @@ class TestCurrentScalingFloat8Tensor:
        """Check numerical error when casting to FP8"""
        # Skip invalid configurations
-        if non_tn_fp8_gemm_supported() and return_transpose:
+        if is_non_tn_fp8_gemm_supported() and return_transpose:
            pytest.skip("FP8 transpose is neither needed nor supported on current system")
        # Initialize random high precision data
...
@@ -12,10 +12,11 @@ from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
-from transformer_engine.pytorch.attention import MultiheadAttention
+from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.utils import gpu_autocast_ctx

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -596,7 +597,7 @@ class AdamTest:
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()
@@ -605,7 +606,7 @@
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()
@@ -647,7 +648,7 @@
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()
@@ -656,7 +657,7 @@
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()
@@ -705,7 +706,7 @@
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()
@@ -714,7 +715,7 @@
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()
...
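The optimizer tests above swap the deprecated torch.cuda.amp.autocast for gpu_autocast_ctx from transformer_engine.pytorch.utils. A minimal sketch of the same pattern, assuming the helper is a drop-in autocast context manager as its use above suggests:

import torch
from transformer_engine.pytorch.utils import gpu_autocast_ctx

model = torch.nn.Linear(64, 64).cuda()
x = torch.randn(8, 64, device="cuda")

# Same structure as the Reference/DUT blocks above: forward under autocast,
# loss reduction, backward outside the context.
with gpu_autocast_ctx(enabled=True):
    y = model(x)
    loss = (y ** 2).mean()
loss.backward()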
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

+from typing import Callable, Tuple, Union
import math
-import pytest
import torch
-from typing import Callable, Tuple, Union
+import pytest
-from transformer_engine.pytorch.dot_product_attention.rope import (
+from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)
@@ -22,6 +22,7 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return torch.sum(output * t)


+@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@@ -43,7 +44,17 @@ def test_fused_rope(
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
+    start_positions: bool,
) -> None:
+    if margin == 0 and start_positions == True:
+        # This ensures that the `start_positions` offsets being applied stay within
+        # the maximum length of the RoPE embeddings.
+        pytest.skip("Skipping test with margin=0 and start_positions=True")
+
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    t = torch.rand(
@@ -51,6 +62,14 @@ def test_fused_rope(
        dtype=dtype,
        device=device,
    )

+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    if transpose:
@@ -69,14 +88,18 @@ def test_fused_rope(
        t.float(),
        emb,
        tensor_format=tensor_format,
+        start_positions=start_positions,
        interleaved=interleaved,
        fused=False,
        cp_size=cp_size,
        cp_rank=cp_rank,
    ).to(dtype)
    loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
-    grad_unfused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+
    t.grad = None

    # fused
@@ -84,21 +107,29 @@
        t,
        emb,
        tensor_format=tensor_format,
+        start_positions=start_positions,
        interleaved=interleaved,
        fused=True,
        cp_size=cp_size,
        cp_rank=cp_rank,
    )
    loss_fused = loss_func(output_fused)
-    loss_fused.backward()
-    grad_fused = t.grad.detach().clone()
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+
    t.grad = None

    torch.testing.assert_close(output_fused, output_unfused)
-    torch.testing.assert_close(grad_fused, grad_unfused)
+    if not isinstance(start_positions, torch.Tensor):
+        torch.testing.assert_close(grad_fused, grad_unfused)
    assert output_fused.is_contiguous()
@pytest.mark.parametrize("margin", [10])
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
...@@ -114,10 +145,25 @@ def test_fused_rope_thd( ...@@ -114,10 +145,25 @@ def test_fused_rope_thd(
loss_func: Callable, loss_func: Callable,
cp_size: int, cp_size: int,
interleaved: bool, interleaved: bool,
start_positions: bool,
margin: int,
) -> None: ) -> None:
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0") device = torch.device("cuda:0")
batch_size, head_num = 2, 64 batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048] cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
if start_positions
else None
)
if cp_size > 1: if cp_size > 1:
cu_seqlens_padded = [0] cu_seqlens_padded = [0]
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
...@@ -152,6 +198,7 @@ def test_fused_rope_thd( ...@@ -152,6 +198,7 @@ def test_fused_rope_thd(
output_unfused = apply_rotary_pos_emb( output_unfused = apply_rotary_pos_emb(
t.float(), t.float(),
emb, emb,
start_positions=start_positions,
tensor_format="thd", tensor_format="thd",
interleaved=interleaved, interleaved=interleaved,
fused=False, fused=False,
...@@ -160,14 +207,17 @@ def test_fused_rope_thd( ...@@ -160,14 +207,17 @@ def test_fused_rope_thd(
cp_rank=cp_rank, cp_rank=cp_rank,
).to(dtype) ).to(dtype)
loss_unfused = loss_func(output_unfused) loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone() if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None t.grad = None
# fused # fused
output_fused = apply_rotary_pos_emb( output_fused = apply_rotary_pos_emb(
t, t,
emb, emb,
start_positions=start_positions,
interleaved=interleaved, interleaved=interleaved,
fused=True, fused=True,
tensor_format="thd", tensor_format="thd",
...@@ -176,9 +226,15 @@ def test_fused_rope_thd( ...@@ -176,9 +226,15 @@ def test_fused_rope_thd(
cp_rank=cp_rank, cp_rank=cp_rank,
) )
loss_fused = loss_func(output_fused) loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone() if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None t.grad = None
torch.testing.assert_close(output_fused, output_unfused) torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
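To illustrate what the new start_positions argument does, a hedged sketch of a direct call to apply_rotary_pos_emb with per-sequence offsets. Argument names follow the test calls above; the "sbhd" layout, the RotaryPositionEmbedding call pattern, and the embedding-table margin are assumptions.

import torch
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

seq_len, batch, heads, dim = 128, 2, 4, 64
t = torch.rand(seq_len, batch, heads, dim, dtype=torch.float32, device="cuda")

# Build the frequency table with some slack so the offsets stay in range.
emb = RotaryPositionEmbedding(dim)(seq_len + 16)

# Each sequence starts consuming RoPE positions at its own offset,
# e.g. when continuing generation from a KV cache.
start_positions = torch.randint(0, 16, (batch,), dtype=torch.int32, device="cuda")
out = apply_rotary_pos_emb(
    t, emb, tensor_format="sbhd", start_positions=start_positions, fused=True
)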
@@ -160,7 +160,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
        normab = torch.cat((a.norm().view(1), b.norm().view(1)))
        norm_per_tensor = norm_per_tensor.view(-1, 2)
    else:
-        norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
+        norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], False)

    reference = torch.full(
        [(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
...
@@ -7,7 +7,6 @@ import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
-import copy
import random
import torch
@@ -40,12 +39,12 @@ from transformer_engine.pytorch import (
    Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
+from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
from transformer_engine.common import recipe
import transformer_engine_torch as tex
@@ -135,18 +134,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
def assert_allclose(
-    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
+    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> bool:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
-        tols = dict(atol=atol)
+        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
+        if atol is not None:
+            tols["atol"] = atol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
-            tol = atol + (rtol * torch.abs(t2))
+            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
            exceed_mask = diff > tol
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
@@ -2304,6 +2305,12 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
        pytest.skip("FusedAttention and FlashAttention do not support FP32")
    if use_RoPE:
        pytest.skip("KV cache does not support starting positions for RoPE")
+    if (
+        backend == "FusedAttention"
+        and get_device_compute_capability() == (8, 9)
+        and get_cudnn_version() < (9, 11, 0)
+    ):
+        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
...
@@ -19,11 +19,12 @@ class TestParallelCrossEntropy:
            label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
        )

-    def generate_input(self, dtype: torch.dtype, swap_dim: bool):
+    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
        SQ = random.choice([64, 128])
        batch = random.choice([1, 2])
        vocab = random.choice([64000, 128000])
+        ignore = random.sample(range(0, SQ - 1), 5)

        if swap_dim:
            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
@@ -32,14 +33,27 @@ class TestParallelCrossEntropy:
            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
            self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()

+        if ignore_idx:
+            for i in ignore:
+                # Ignore 5 indices
+                if swap_dim:
+                    self.tar_test[i][0] = -100
+                else:
+                    self.tar_test[0][i] = -100
+
        self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
        self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

    def one_iteration_test(
-        self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        label_smoothing: float,
+        reduce_loss: bool,
+        ignore_idx: bool = False,
    ):
-        self.generate_input(dtype, swap_dim)
+        self.generate_input(dtype, swap_dim, ignore_idx)

        self.input_test.requires_grad_(True)
        self.input_ref.requires_grad_(True)
@@ -57,6 +71,8 @@ class TestParallelCrossEntropy:
        test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss

        torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
+        if ignore_idx:
+            print(test_loss, ref_loss)

        if reduce_loss:
            torch.testing.assert_close(
                torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
@@ -106,3 +122,15 @@ class TestParallelCrossEntropy:
        self.one_iteration_test(
            dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
        )

+    def test_ignore_idx(self):
+        self.generate_iters(5)
+        self.generate_infra(False, 0)
+        for i in range(self.iters):
+            self.one_iteration_test(
+                dtype=torch.float32,
+                swap_dim=random.choice([True, False]),
+                label_smoothing=0,
+                reduce_loss=False,
+                ignore_idx=True,
+            )
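The new ignore_idx path marks targets with -100. For reference, this matches the default ignore_index handling of torch.nn.CrossEntropyLoss, which the parallel implementation is compared against (a small standalone illustration, not taken from the test file):

import torch

logits = torch.randn(4, 10)
target = torch.tensor([1, -100, 3, -100])  # -100 positions are ignored
loss = torch.nn.CrossEntropyLoss(reduction="none")(logits, target)
print(loss)  # entries at the ignored positions come back as exactly 0.0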
@@ -373,7 +373,9 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
    torch.cuda.synchronize()


-def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
+def _test_sanity_common(
+    block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching=True
+):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
@@ -389,7 +391,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad
    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
-        te_out = block(te_inp)
+        if not microbatching:
+            te_out = block(te_inp)
+        else:
+            _ = block(te_inp, is_first_microbatch=True)
+            te_out = block(te_inp, is_first_microbatch=False)
    if isinstance(te_out, tuple):
        te_out = te_out[0]
    loss = te_out.sum()
@@ -443,8 +449,16 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_linear(
-    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
+    dtype,
+    fp8_recipe,
+    model,
+    skip_wgrad,
+    zero_centered_gamma,
+    skip_dgrad,
+    normalization,
+    microbatching,
):
    config = model_configs[model]
@@ -470,7 +484,7 @@ def test_sanity_layernorm_linear(
        params_dtype=dtype,
        device="cuda",
    )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@@ -478,7 +492,8 @@ def test_sanity_layernorm_linear(
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
-def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
+@pytest.mark.parametrize("microbatching", all_boolean)
+def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching):
    config = model_configs[model]

    if fp8_recipe is not None:
@@ -501,7 +516,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
        params_dtype=dtype,
        device="cuda",
    )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
@@ -600,8 +615,17 @@ def test_sanity_grouped_linear(
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("microbatching", all_boolean)
def test_sanity_layernorm_mlp(
-    dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
+    dtype,
+    fp8_recipe,
+    model,
+    skip_wgrad,
+    zero_centered_gamma,
+    skip_dgrad,
+    activation,
+    normalization,
+    microbatching,
):
    config = model_configs[model]
@@ -630,7 +654,7 @@ def test_sanity_layernorm_mlp(
        params_dtype=dtype,
        device="cuda",
    )
-    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
+    _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)


@pytest.mark.parametrize("dtype", param_types)
...
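The microbatching variants above call each block twice with is_first_microbatch=True/False. A sketch of that pattern on a generic TE module; te.Linear as the example module is an assumption, and the flag semantics (cache FP8-cast weights on the first microbatch, reuse them afterwards) follow Transformer Engine's documented behaviour:

import torch
import transformer_engine.pytorch as te

block = te.Linear(128, 128, params_dtype=torch.bfloat16, device="cuda")
inp = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True):               # requires FP8-capable hardware
    _ = block(inp, is_first_microbatch=True)      # cast + cache FP8 weights
    out = block(inp, is_first_microbatch=False)   # reuse the cached weights
out.sum().backward()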
@@ -11,12 +11,12 @@ import transformer_engine.common

try:
    from . import pytorch
-except (ImportError, StopIteration) as e:
+except ImportError as e:
    pass

try:
    from . import jax
-except (ImportError, StopIteration) as e:
+except ImportError as e:
    pass

__version__ = str(metadata.version("transformer_engine"))
@@ -111,6 +111,11 @@ if(USE_CUDA)
cudnn_utils.cpp
transformer_engine.cpp
common.cu
+multi_tensor/adam.cu
+multi_tensor/compute_scale.cu
+multi_tensor/l2norm.cu
+multi_tensor/scale.cu
+multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
@@ -148,6 +153,7 @@ if(USE_CUDA)
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
+recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -158,6 +164,11 @@ else()
cudnn_utils.cpp
transformer_engine.cpp
common.cu
+multi_tensor/adam.cu
+multi_tensor/compute_scale.cu
+multi_tensor/l2norm.cu
+multi_tensor/scale.cu
+multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
@@ -191,6 +202,7 @@ else()
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
+recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -345,6 +357,14 @@ target_include_directories(transformer_engine PRIVATE
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
+multi_tensor/adam.cu
+multi_tensor/compute_scale.cu
+multi_tensor/l2norm.cu
+multi_tensor/scale.cu
+multi_tensor/sgd.cu
+fused_attn/flash_attn.cu
+fused_attn/context_parallel.cu
+fused_attn/kv_cache.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")

option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
...
@@ -9,28 +9,193 @@ import glob
import sysconfig
import subprocess
import ctypes
+import logging
import os
import platform
+import importlib
+import functools
from pathlib import Path
+from importlib.metadata import version, metadata, PackageNotFoundError

-import transformer_engine

+_logger = logging.getLogger(__name__)

-def is_package_installed(package):
-    """Checks if a pip package is installed."""
-    return (
-        subprocess.run(
-            [sys.executable, "-m", "pip", "show", package], capture_output=True, check=False
-        ).returncode
-        == 0
-    )


+@functools.lru_cache(maxsize=None)
+def _is_pip_package_installed(package):
+    """Check if the given package is installed via pip."""
+    # This is needed because we only want to return true
+    # if the python package is installed via pip, and not
+    # if it's importable in the current directory due to
+    # the presence of the shared library module.
+    try:
+        metadata(package)
+    except PackageNotFoundError:
+        return False
+    return True
@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str):
"""
Find a shared object file of given prefix in the top level TE directory.
Only the following locations are searched to avoid stray SOs and build
artifacts:
1. The given top level directory (editable install).
2. `transformer_engine` named directories (source install).
3. `wheel_lib` named directories (PyPI install).
Returns None if no shared object files are found.
Raises an error if multiple shared object files are found.
"""
# Ensure top level dir exists and has the module before searching.
if not te_path.exists() or not (te_path / "transformer_engine").exists():
return None
files = []
search_paths = (
te_path,
te_path / "transformer_engine",
te_path / "transformer_engine/wheel_lib",
te_path / "wheel_lib",
)
# Search.
for dirname, _, names in os.walk(te_path):
if Path(dirname) in search_paths:
for name in names:
if name.startswith(prefix) and name.endswith(f".{_get_sys_extension()}"):
files.append(Path(dirname, name))
if len(files) == 0:
return None
if len(files) == 1:
return files[0]
raise RuntimeError(f"Multiple files found: {files}")
@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
"""
Return the path of the shared object file for the given TE
library, one of 'core', 'torch', or 'jax'.
Several factors affect finding the correct location of the shared object:
1. System and environment.
2. If the installation is from source or via PyPI.
- Source installed .sos are placed in top level dir
- Wheel/PyPI installed .sos are placed in 'wheel_lib' dir to avoid conflicts.
3. For source installations, is the install editable/inplace?
4. The user directory from where TE is being imported.
"""
# Check provided input and determine the correct prefix for .so.
assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}."
if library == "core":
so_prefix = "libtransformer_engine"
else:
so_prefix = f"transformer_engine_{library}"
# Check TE install location (will be local if TE is available in current dir for import).
te_install_dir = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
so_path_in_install_dir = _find_shared_object_in_te_dir(te_install_dir, so_prefix)
# Check default python package install location in system.
site_packages_dir = Path(sysconfig.get_paths()["purelib"])
so_path_in_default_dir = _find_shared_object_in_te_dir(site_packages_dir, so_prefix)
# Case 1: Typical user workflow: Both locations are the same, return any result.
if te_install_dir == site_packages_dir:
assert (
so_path_in_install_dir is not None
), f"Could not find shared object file for Transformer Engine {library} lib."
return so_path_in_install_dir
# Case 2: ERR! Both locations are different but returned a valid result.
# NOTE: Unlike for source installations, pip does not wipe out artifacts from
# editable builds. In case developers are executing inside a TE directory via
# an inplace build, and then move to a regular build, the local shared object
# file will be incorrectly picked up without the following logic.
if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
raise RuntimeError(
f"Found multiple shared object files: {so_path_in_install_dir} and"
f" {so_path_in_default_dir}. Remove local shared objects installed"
f" here {so_path_in_install_dir} or change the working directory to"
"execute from outside TE."
)
# Case 3: Typical dev workflow: Editable install
if so_path_in_install_dir is not None:
return so_path_in_install_dir
# Case 4: Executing from inside a TE directory without an inplace build available.
if so_path_in_default_dir is not None:
return so_path_in_default_dir
raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str):
"""
Load shared library with Transformer Engine framework bindings
and verify correctness if installed via PyPI.
"""
# Supported frameworks.
assert framework in ("jax", "torch"), f"Unsupported framework {framework}"
# Name of the framework extension library.
module_name = f"transformer_engine_{framework}"
# Name of the pip extra dependency for framework extensions from PyPI.
extra_dep_name = module_name
if framework == "torch":
extra_dep_name = "pytorch"
# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
if _is_pip_package_installed(module_name):
assert _is_pip_package_installed(
"transformer_engine"
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)
# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed("transformer-engine-cu12"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
)
+# After all checks are completed, load the shared object file.
+spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
+solib = importlib.util.module_from_spec(spec)
+sys.modules[module_name] = solib
+spec.loader.exec_module(solib)

-def get_te_path():
-    """Find Transformer Engine install path using pip"""
-    return Path(transformer_engine.__path__[0]).parent

+@functools.lru_cache(maxsize=None)
def _get_sys_extension():
    system = platform.system()
    if system == "Linux":
@@ -45,20 +210,47 @@ def _get_sys_extension():
    return extension

-def _load_cudnn():
-    """Load CUDNN shared library."""
-    # Attempt to locate cuDNN in Python dist-packages
-    lib_path = glob.glob(
+@functools.lru_cache(maxsize=None)
+def _load_nvidia_cuda_library(lib_name: str):
+    """
+    Attempts to load shared object file installed via pip.
+
+    `lib_name`: Name of package as found in the `nvidia` dir in python environment.
+    """
+    so_paths = glob.glob(
        os.path.join(
            sysconfig.get_path("purelib"),
-            f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
+            f"nvidia/{lib_name}/lib/lib*.{_get_sys_extension()}.*[0-9]",
        )
    )
-    if lib_path:
-        assert (
-            len(lib_path) == 1
-        ), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
-        return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)
+
+    path_found = len(so_paths) > 0
+    ctypes_handles = []
+
+    if path_found:
+        for so_path in so_paths:
+            ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
+
+    return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir():
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
try:
import nvidia
except ModuleNotFoundError:
return ""
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
return str(include_dir) if include_dir.exists() else ""

+@functools.lru_cache(maxsize=None)
+def _load_cudnn():
+    """Load CUDNN shared library."""
    # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
    cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
@@ -75,28 +267,16 @@ def _load_cudnn():
    if libs:
        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

+    # Attempt to locate cuDNN in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("cudnn")
+    if found:
+        return handle
+
    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
    return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)

-def _load_library():
-    """Load shared library with Transformer Engine C extensions"""
-    so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}"
-    if not so_path.exists():
-        so_path = (
-            get_te_path()
-            / "transformer_engine"
-            / "wheel_lib"
-            / f"libtransformer_engine.{_get_sys_extension()}"
-        )
-    if not so_path.exists():
-        so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}"
-    assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
-    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)


+@functools.lru_cache(maxsize=None)
def _load_nvrtc():
    """Load NVRTC shared library."""
    # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
@@ -107,6 +287,11 @@ def _load_nvrtc():
    if libs:
        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

+    # Attempt to locate NVRTC in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
+    if found:
+        return handle
+
    # Attempt to locate NVRTC via ldconfig
    libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
    libs = libs.decode("utf-8").split("\n")
@@ -123,10 +308,22 @@ def _load_nvrtc():
    return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)

+@functools.lru_cache(maxsize=None)
+def _load_core_library():
+    """Load shared library with Transformer Engine C extensions"""
+    return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_GLOBAL)


if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
    try:
        _CUDNN_LIB_CTYPES = _load_cudnn()
        _NVRTC_LIB_CTYPES = _load_nvrtc()
+        _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
+        _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
+
+        # Needed to find the correct headers for NVRTC kernels.
+        if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
+            os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
    except OSError:
        pass
-    _TE_LIB_CTYPES = _load_library()
+    _TE_LIB_CTYPES = _load_core_library()
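A hypothetical call site for the new loader (not shown in this diff): the framework extension packages would load their bindings through load_framework_extension instead of importing the shared object directly, assuming the helper lives in transformer_engine.common, the file modified here.

# Hypothetical usage sketch; the call site is an assumption.
from transformer_engine.common import load_framework_extension

# Registers sys.modules["transformer_engine_torch"] once the PyPI-consistency
# checks inside the helper pass.
load_framework_extension("torch")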
@@ -21,12 +21,18 @@
#define HALF_BYTES 2
#define UB_MAX_SM 32

-#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)

using namespace std::placeholders;

namespace transformer_engine {

+namespace {
+
+std::vector<size_t> shape_to_vector(const NVTEShape &shape) {
+  return std::vector<size_t>(shape.data, shape.data + shape.ndim);
+}
+
+}  // namespace
+
/***************************************************************************************************
 * Comm+GEMM Overlap Common Core
 **************************************************************************************************/
@@ -147,13 +153,50 @@ CommOverlapCore::~CommOverlapCore() {

TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
                                                const std::vector<size_t> &chunk_shape) {
-  TensorWrapper chunk;
+  const auto scaling_mode = source.scaling_mode();
+
+  // Tensor dimensions
+  std::vector<size_t> shape = shape_to_vector(source.shape());
+  auto flatten_shape_to_2d = [](const std::vector<size_t> &shape) -> std::pair<size_t, size_t> {
+    if (shape.empty()) {
+      return {1, 1};
+    }
+    size_t height = 1;
+    for (size_t i = 0; i < shape.size() - 1; ++i) {
+      height *= shape[i];
+    }
+    return {height, shape.back()};
+  };
+  size_t height, width, chunk_height, chunk_width;
+  std::tie(height, width) = flatten_shape_to_2d(shape);
+  std::tie(chunk_height, chunk_width) = flatten_shape_to_2d(chunk_shape);
+
+  // Check tensor dimensions
+#define NVTE_DIM_CHECK(cond, message) \
+  NVTE_CHECK(cond, message, " (tensor shape=", shape, ", chunk shape=", chunk_shape, \
+             ", chunk offset=", chunk_offset, ")")
+  NVTE_DIM_CHECK(height > 0 && width > 0, "Attempted to get chunk from empty tensor");
+  NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
+  NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
+                 "Attempted to get out-of-bounds tensor chunk");
+  if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
+    // MXFP8 scale-inverses are padded to a 2D matrix with dims that
+    // are divisible by 128. UB doesn't handle this padding yet.
+    NVTE_DIM_CHECK(height % 128 == 0 && width % 128 == 0,
+                   "Userbuffers requires MXFP8 tensor dims that are divisible by 128");
+    NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
+                   "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
+  }
+#undef NVTE_DIM_CHECK
+
+  // Construct tensor chunk
+  TensorWrapper chunk(scaling_mode);
  for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
    auto param_type = static_cast<NVTETensorParam>(param_id);
    auto param = source.get_parameter(param_type);
    auto param_dptr = reinterpret_cast<char *>(param.data_ptr);
    auto param_dtype = static_cast<DType>(param.dtype);
-    auto param_shape = AS_VECTOR(param.shape);
+    auto param_shape = shape_to_vector(param.shape);

    if (param_dptr != nullptr) {
      if (param_type == NVTETensorParam::kNVTERowwiseData ||
@@ -163,8 +206,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
        param_shape = chunk_shape;

        if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
-            source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
-          // Columnwise shape for non-block scaled tensors shifts the last dimension to the front
+            source.scaling_mode() == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
+          // Columnwise shape for FP8 tensor-scaled tensors shifts the last dimension to the front
          auto last_dim = param_shape.back();
          param_shape.pop_back();
          param_shape.insert(param_shape.begin(), last_dim);
@@ -172,18 +215,16 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
      } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
                 (param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
                  param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
-        // Calculate block scaling offset and size
-        auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
-                                          ? source.shape().data[0]
-                                          : source.columnwise_shape().data[0];
-        auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
-                                         ? chunk_shape.front()
-                                         : chunk_shape.back();
-        auto chunk_scale_start = chunk_offset / 32;
-        auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
-        auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
-        param_dptr += chunk_scale_start * typeToSize(param_dtype);
-        param_shape = std::vector<size_t>{chunk_scale_size};
+        // Calculate offset and size for MXFP8 scale-invs
+        size_t chunk_scale_height = chunk_height;
+        size_t chunk_scale_width = chunk_width;
+        if (param_type == NVTETensorParam::kNVTERowwiseScaleInv) {
+          chunk_scale_width /= 32;
+        } else {
+          chunk_scale_height /= 32;
+        }
+        param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
+        param_shape = {chunk_scale_height, chunk_scale_width};
      }

      // Set chunked source parameters into the chunked tensor output
@@ -434,10 +475,21 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  size_t k = transa ? A.size(1) : A.size(0);
  size_t n = _ubuf.size(0);
  size_t m_chunk = m / _num_splits;
+  const std::vector<size_t> input_a_chunk_shape =
+      (transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
+  const std::vector<size_t> output_chunk_shape = {n, m_chunk};
  size_t input_a_chunk_size = m_chunk * k;
  size_t output_chunk_size = n * m_chunk;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

+  // Helper function to get bias chunk if needed
+  auto maybe_get_bias_chunk = [this, &bias, m_chunk](size_t chunk_id) -> TensorWrapper {
+    if (bias.dptr() == nullptr) {
+      return TensorWrapper();
+    }
+    return get_tensor_chunk(bias, chunk_id * m_chunk, {m_chunk});
+  };
+
  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  for (size_t i = 0; i < _stream_compute.size(); i++) {
@@ -449,12 +501,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());

  if (_rs_overlap_first_gemm) {
-    auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
-    auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
+    auto input_a_chunk = get_tensor_chunk(A, 0, input_a_chunk_shape);
+    auto output_chunk = get_buffer_chunk_like(D, 0, output_chunk_shape);
+    auto bias_chunk = maybe_get_bias_chunk(0);
    auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});

    if (_ub_stream_nums == 1) {
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                       use_split_accumulator, _math_sms, _stream_compute[0]);
    } else {
@@ -464,18 +517,19 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
    }

    for (int i = 1; i < _num_splits; i++) {
-      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
-      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
+      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
+      bias_chunk = maybe_get_bias_chunk(i);
      workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

      if (_ub_stream_nums == 1) {
-        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                         accumulate, use_split_accumulator, _math_sms,
                         _stream_compute[i % _stream_compute.size()]);
      } else {
-        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                         accumulate, use_split_accumulator, _math_sms,
                         _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
@@ -519,13 +573,14 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
      }
  } else {
    for (int i = 0; i < _num_splits; i++) {
-      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
-      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, input_a_chunk_shape);
+      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, output_chunk_shape);
auto bias_chunk = maybe_get_bias_chunk(i);
auto workspace_chunk = get_tensor_chunk( auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) { if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms, accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]); _stream_compute[i % _stream_compute.size()]);
...@@ -605,14 +660,17 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, ...@@ -605,14 +660,17 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
void *buffer_ptr; void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, _ubuf = TensorWrapper(
buffer_dtype); buffer_ptr,
std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management // Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr); char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) { for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr), _ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
{buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype)); std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes; ubuf_byte_ptr += buffer_chunk_bytes;
} }
...@@ -661,7 +719,7 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { ...@@ -661,7 +719,7 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) { size_t chunk_id) {
// Start with a chunk of the source tensor // Start with a chunk of the source tensor
auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape())); auto chunk = get_tensor_chunk(source, 0, shape_to_vector(_ubufs[chunk_id].shape()));
// Update chunk with offset data pointers from the communication buffer // Update chunk with offset data pointers from the communication buffer
if (chunk.dptr() != nullptr) { if (chunk.dptr() != nullptr) {
...@@ -711,7 +769,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( ...@@ -711,7 +769,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape())); auto input_b = get_buffer_chunk_like(B, 0, shape_to_vector(B.shape()));
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
...@@ -798,8 +856,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -798,8 +856,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// Get communication and GEMM output chunk sizes // Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0; const bool do_gelu = pre_gelu_out.numel() > 0;
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
...@@ -810,10 +866,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -810,10 +866,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
} }
if (_aggregate) { if (_aggregate) {
const int num_steps = _tp_size / 2; const int num_steps = _tp_size / 2;
#ifndef __HIP_PLATFORM_AMD__
input_chunk_size *= 2; // Chunk dims
output_chunk_size *= 2; std::vector<size_t> input_b_chunk_shape =
#endif (transb ? std::vector<size_t>{k, 2 * n_chunk} : std::vector<size_t>{2 * n_chunk, k});
std::vector<size_t> output_chunk_shape = {2 * n_chunk, k};
size_t input_b_chunk_size = 2 * n_chunk * k;
size_t output_chunk_size = 2 * n_chunk * m;
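// The aggregate path runs each GEMM over two neighboring 1X chunks at once, hence the
// 2 * n_chunk factors in the chunk shapes and sizes above.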
// Initial 1X input chunk exchange between neighboring peers // Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id; int send_chunk_id = _tp_id;
...@@ -842,8 +901,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -842,8 +901,9 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
// GEMM // GEMM
auto input_b_chunk = auto input_b_chunk =
get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k}); get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m}); auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk = auto aux_chunk =
(do_gelu) (do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k}) ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
...@@ -882,6 +942,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -882,6 +942,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
} }
} }
} else { } else {
// Chunk dims
std::vector<size_t> input_b_chunk_shape =
(transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
std::vector<size_t> output_chunk_shape = {n_chunk, m};
size_t input_b_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
for (int i = 0; i < _tp_size; i++) { for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current // Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
...@@ -893,8 +960,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -893,8 +960,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
int recv_offset = comm_bytes * recv_chunk_id; int recv_offset = comm_bytes * recv_chunk_id;
// GEMM // GEMM
auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k}); auto input_b_chunk =
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m}); get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
auto output_chunk =
get_tensor_chunk(D, output_chunk_size * send_chunk_id, output_chunk_shape);
auto aux_chunk = auto aux_chunk =
(do_gelu) (do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k}) ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
...@@ -972,7 +1041,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( ...@@ -972,7 +1041,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
// Atomic GEMM // Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks. // Process GEMM chunks in the order that AG+GEMM places the output chunks.
auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape())); auto output_d = get_buffer_chunk_like(D, 0, shape_to_vector(D.shape()));
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, transa, transb, grad, workspace.data(), accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, true, _counter.data(), stream_main); _math_sms, 0, _tp_size, true, _counter.data(), stream_main);
...@@ -1053,6 +1122,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, ...@@ -1053,6 +1122,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
auto output_chunk = get_buffer_chunk_by_id(D, i); auto output_chunk = get_buffer_chunk_by_id(D, i);
auto workspace_chunk = auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
......
...@@ -35,6 +35,65 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { ...@@ -35,6 +35,65 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
} }
} }
namespace {
constexpr size_t kThreadsPerBlock = 256;
template <typename TVectorized>
__global__ void __launch_bounds__(kThreadsPerBlock)
memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx * sizeof(TVectorized) >= size_in_bytes) {
return; // Out of bounds
}
if ((idx + 1) * sizeof(TVectorized) > size_in_bytes) {
// If the buffer size is not a whole multiple of the vector width, fall back to a plain byte-wise memset for the remaining tail bytes.
size_t remaining_bytes = size_in_bytes - idx * sizeof(TVectorized);
memset(reinterpret_cast<uint8_t *>(ptr) + idx * sizeof(TVectorized), value, remaining_bytes);
return;
}
union {
TVectorized value;
uint8_t data[sizeof(TVectorized)];
} data;
for (size_t i = 0; i < sizeof(TVectorized); ++i) {
data.data[i] = static_cast<uint8_t>(value);
}
reinterpret_cast<TVectorized *>(ptr)[idx] = data.value;
}
} // namespace
#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
if (size_in_bytes >= sizeof(vectorizedType) && \
reinterpret_cast<size_t>(ptr) % sizeof(vectorizedType) == 0) { \
size_t numBlocks = DIVUP(size_in_bytes, kThreadsPerBlock * sizeof(vectorizedType)); \
dim3 grid(numBlocks, 1, 1); \
memset_kernel<vectorizedType> \
<<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes); \
return; \
}
extern "C" {
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
NVTE_API_CALL(nvte_memset);
NVTE_CHECK(ptr != nullptr, "Pointer for memset must be allocated.");
if (size_in_bytes > 4096) {
// Use cudaMemsetAsync for larger sizes.
cudaMemsetAsync(ptr, value, size_in_bytes, stream);
return;
}
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float4, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float2, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
}
} // extern "C"
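A minimal usage sketch for the new nvte_memset entry point, assuming a buffer from cudaMalloc
(at least 256-byte aligned) and a caller-provided stream; zero_scale_inv_example is a
hypothetical helper, not part of the library:
void zero_scale_inv_example(cudaStream_t stream) {
  void *scale_inv_buf = nullptr;
  cudaMalloc(&scale_inv_buf, 256);
  // 256 bytes <= 4096, so the kernel path is used; the float4 variant wins the dispatch:
  // DIVUP(256, 256 * 16) = 1 block of 256 threads, 16 of which store one float4 each while
  // the rest exit on the bounds check.
  nvte_memset(scale_inv_buf, 0, 256, stream);
  cudaStreamSynchronize(stream);
  cudaFree(scale_inv_buf);
}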
void checkCuDriverContext(CUstream stream) { void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
return; return;
...@@ -144,4 +203,16 @@ bool is_supported_by_CC_100() { ...@@ -144,4 +203,16 @@ bool is_supported_by_CC_100() {
#endif #endif
} }
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size) {
std::vector<std::vector<Tensor *>> ret;
for (size_t i = 0; i < outer_size; ++i) {
ret.emplace_back();
for (size_t j = 0; j < inner_size; ++j) {
ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
}
}
return ret;
}
} // namespace transformer_engine } // namespace transformer_engine
...@@ -116,7 +116,7 @@ struct Tensor { ...@@ -116,7 +116,7 @@ struct Tensor {
columnwise_scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
int numel() const { size_t numel() const {
size_t acc = 1; size_t acc = 1;
for (const auto dim : shape()) { for (const auto dim : shape()) {
acc *= dim; acc *= dim;
...@@ -138,6 +138,14 @@ struct Tensor { ...@@ -138,6 +138,14 @@ struct Tensor {
return data.dtype; return data.dtype;
} }
size_t dim() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size();
} else {
return data.shape.size();
}
}
std::vector<size_t> shape() const { std::vector<size_t> shape() const {
/* Note: We sometimes experience spurious compiler errors /* Note: We sometimes experience spurious compiler errors
* (-Wstringop-overflow) from this function. It appears that GCC * (-Wstringop-overflow) from this function. It appears that GCC
...@@ -243,6 +251,7 @@ constexpr T DIVUP(const T &x, const T &y) { ...@@ -243,6 +251,7 @@ constexpr T DIVUP(const T &x, const T &y) {
} }
using byte = uint8_t; using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t; using int32 = int32_t;
using int64 = int64_t; using int64 = int64_t;
using fp32 = float; using fp32 = float;
...@@ -271,6 +280,7 @@ constexpr inline const char *type_name() noexcept; ...@@ -271,6 +280,7 @@ constexpr inline const char *type_name() noexcept;
return #T; \ return #T; \
} }
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t) TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t) TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t) TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float) TRANSFORMER_ENGINE_TYPE_NAME(float)
...@@ -327,7 +337,7 @@ struct TypeExtrema { ...@@ -327,7 +337,7 @@ struct TypeExtrema {
template <typename T> template <typename T>
struct TypeInfo { struct TypeInfo {
using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>; using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
template <typename U, DType current> template <typename U, DType current>
struct Helper { struct Helper {
...@@ -364,6 +374,10 @@ struct TypeInfo { ...@@ -364,6 +374,10 @@ struct TypeInfo {
using type = unsigned char; \ using type = unsigned char; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} break; \ } break; \
case DType::kInt16: { \
using type = int16_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt32: { \ case DType::kInt32: { \
using type = int32_t; \ using type = int32_t; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
...@@ -400,6 +414,33 @@ struct TypeInfo { ...@@ -400,6 +414,33 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
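// Usage sketch (my_kernel is a hypothetical templated kernel): the macro switches on the runtime
// DType and binds the matching C++ type to the given name inside the body, e.g.
//   TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(tensor.dtype(), dtype,
//       my_kernel<dtype><<<grid, block, 0, stream>>>(
//           reinterpret_cast<dtype *>(tensor.data.dptr)););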
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
switch (dtype) { \ switch (dtype) { \
using namespace transformer_engine; \ using namespace transformer_engine; \
...@@ -599,6 +640,9 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, ...@@ -599,6 +640,9 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
bool is_supported_by_CC_100(); bool is_supported_by_CC_100();
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size);
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
const size_t S, const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
}
}
void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.dtype() == DType::kFloat16 || qkvi.dtype() == DType::kBFloat16);
auto qkvi_shape = qkvi.shape();
NVTE_CHECK(qkvi_shape[3] % load_size == 0);
NVTE_CHECK(qkvi_shape[3] == load_size);
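// Note: load_size = warp_size * nvec = 32 * 4 = 128 elements for 2-byte FP16/BF16, so this
// check effectively restricts the kernel to a head dimension of 128.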
// [s, b, n, h * 3] -> [3, b, s, n, h]
std::vector<uint64_t> shape = {3, qkvi_shape[1], qkvi_shape[0], qkvi_shape[2], qkvi_shape[3]};
size_t warps = qkvi_shape[0] * qkvi_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
qkvi.dtype(), dtype,
prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
shape[1], shape[2], shape[3], shape[4]););
}
void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(q.dtype() == DType::kFloat16 || q.dtype() == DType::kBFloat16);
NVTE_CHECK(k.dtype() == q.dtype());
NVTE_CHECK(v.dtype() == q.dtype());
auto q_shape = q.shape();
auto k_shape = k.shape();
auto v_shape = v.shape();
NVTE_CHECK(q_shape[3] % load_size == 0);
NVTE_CHECK(q_shape[3] == load_size);
NVTE_CHECK(k_shape[3] % load_size == 0);
NVTE_CHECK(k_shape[3] == load_size);
NVTE_CHECK(v_shape[3] % load_size == 0);
NVTE_CHECK(v_shape[3] == load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<uint64_t> shape = {q_shape[1], q_shape[0], q_shape[2], 3 * q_shape[3]};
size_t warps = q_shape[0] * q_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
q.dtype(), dtype,
prepare_kernel_bwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
}
} // namespace flash_attention
} // namespace transformer_engine
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
*reinterpret_cast<Tensor *>(qkv), stream);
}
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_bwd(
*reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
*reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
}
...@@ -1006,3 +1006,18 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -1006,3 +1006,18 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
} }
} }
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream) {
NVTE_API_CALL(nvte_get_runtime_num_segments);
using namespace transformer_engine::fused_attn;
return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream);
}
void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
size_t q_max_seqlen, size_t kv_max_seqlen,
NVTE_Fused_Attn_Backend backend, cudaStream_t stream) {
NVTE_API_CALL(nvte_populate_rng_state_async);
using namespace transformer_engine::fused_attn;
PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
}
...@@ -3,48 +3,15 @@ ...@@ -3,48 +3,15 @@
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
namespace transformer_engine { #include "../common.h"
namespace fused_attn { #include "transformer_engine/fused_attn.h"
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
template <typename scalar_t> namespace transformer_engine {
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, namespace kv_cache {
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t> template <typename dtype>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices, __global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
int d_v, int b, int max_seq_len) { int d_v, int b, int max_seq_len) {
// k_cache, v_cache: bshd // k_cache, v_cache: bshd
...@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in ...@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
} }
} }
template <typename scalar_t> template <typename dtype>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache, __global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
scalar_t *v_cache, int *page_table, int *cu_new_lens, int *page_table, int *cu_new_lens, int *cu_cached_lens,
int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) { int max_pages_per_seq, bool is_non_paged) {
// new_k, new_v: qkv_format; k_cache, v_cache: bshd // new_k, new_v: qkv_format; k_cache, v_cache: bshd
// cu_new_lens, cu_cached_lens: [b + 1] // cu_new_lens, cu_cached_lens: [b + 1]
...@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar ...@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
} }
} }
} }
} // namespace fused_attn
template <typename dtype>
void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
}
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache, Tensor page_table,
Tensor cu_new_lens, Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged,
cudaStream_t stream) {
int h_kv = new_k.shape()[new_k.dim() - 2];
int d_k = new_k.shape()[new_k.dim() - 1];
int d_v = new_v.shape()[new_v.dim() - 1];
NVTE_CHECK(k_cache.dtype() == v_cache.dtype() && new_k.dtype() == new_v.dtype() &&
new_k.dtype() == k_cache.dtype(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
k_cache.dtype(), dtype,
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged, stream););
}
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
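// Example with hypothetical cu_seqlens = [0, 3, 5] (b = 2): batch 1 copies (5 - 3) * h * d
// elements from thd offset 3 * h * d to bshd offset 1 * max_seq_len * h * d; positions past the
// true sequence length are left untouched by this kernel.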
template <typename scalar_t>
void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_thd_to_bshd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
int max_seq_len, cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
new_tensor.dtype(), dtype,
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len,
tensor_shape[1], tensor_shape[2], stream););
}
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_bshd_to_thd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
tensor.dtype(), dtype,
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, tensor_shape[0],
tensor_shape[1], tensor_shape[2], tensor_shape[3],
stream););
}
} // namespace kv_cache
} // namespace transformer_engine } // namespace transformer_engine
#endif
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
 * 2. cu_new_lens and cu_cached_lens have shape [b + 1]; cu_cached_lens includes the lengths added
 *    in the current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
 *    Set is_non_paged = True/False to indicate which case applies.
 * 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch_indices [0, 3, 1, 2]
 *    becomes [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is, however, unchanged.
 *    batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3]. batch_indices is
 *    preserved for the next layer in the same iteration.
 * 5. Only the same page_table for k_cache and v_cache is supported.
 * 6. Only pad_between_seqs = False is supported when qkv_format = thd, i.e. there must be no pad
 *    tokens between sequences in new_k and new_v (no layouts like [a a a 0..0 b b 0..0 c 0..0]).
**************************************************************************************************/
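A minimal host-side sketch of the destination-offset arithmetic described in points 1-3 above,
assuming the cache stores page_size tokens per physical page (so max_seq_len = page_size *
max_pages_per_seq); kv_cache_offset is a hypothetical helper, not part of the library API:
#include <cstddef>
// Element offset of (batch b, cached position pos) inside the bshd-formatted KV cache.
size_t kv_cache_offset(const int *page_table, int b, int pos, int page_size,
                       int max_pages_per_seq, int h_kv, int d) {
  int page = page_table[b * max_pages_per_seq + pos / page_size];  // physical page for this token
  int slot = pos % page_size;                                      // slot within that page
  return (static_cast<size_t>(page) * page_size + slot) * h_kv * d;
}
// The non-paged case is max_pages_per_seq = 1 with page_size = max_seq_len, where page_table[b]
// is simply the (re-indexed) batch index.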
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream) {
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
kv_cache::copy_to_kv_cache(
*reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
*reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
*reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
*reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
max_pages_per_seq, is_non_paged, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), t, stream);
}