Unverified commit fdc09f42, authored by Przemyslaw Tredak, committed by GitHub

Exposing RMSNorm in pyTorch (#306)



* Exposing RMSNorm in pyTorch extensions
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* First pass at the Python API
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Small fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added numerics tests and fixed issues
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Lint fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added RMSNorm to LayerNormMLP
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added ONNX export and tests for RMSNorm
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix python lint
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix BERT case
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added normalization option to the TransformerLayer
Added tests
Fixed test failures
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix documentation
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix kwarg bug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix IMA and invalid type error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Increase RMSNorm threshold for bf16 case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 06d5fa97
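Taken together, the changes below expose RMSNorm as a standalone pyTorch module and as a "normalization" option on the fused modules (LayerNormLinear, LayerNormMLP, MultiHeadAttention, TransformerLayer). A minimal usage sketch of the resulting API, assuming only the signatures visible in this diff (hidden_size, eps=1e-5, and a normalization keyword defaulting to "LayerNorm"):

import torch
import transformer_engine.pytorch as te

hidden_size = 1024
x = torch.randn(16, hidden_size, device="cuda")

# Standalone module: RMSNorm learns a weight (gamma) but, unlike LayerNorm, has no bias.
rmsnorm = te.RMSNorm(hidden_size, eps=1e-5).cuda()
y = rmsnorm(x)

# Fused modules select the norm flavor through a keyword argument.
ln_mlp = te.LayerNormMLP(
    hidden_size,
    4 * hidden_size,
    normalization="RMSNorm",  # default is "LayerNorm"
).cuda()

Per the test skips added below, zero_centered_gamma=True is not yet supported together with normalization="RMSNorm".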
@@ -19,6 +19,7 @@ directly from C/C++, without Python.
   gemm.h <gemm>
   fused_attn.h <fused_attn>
   layer_norm.h <layer_norm>
+  rmsnorm.h <rmsnorm>
   softmax.h <softmax>
   transformer_engine.h <transformer_engine>
   transpose.h <transpose>
..
    Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

rmsnorm.h
=========

.. doxygenfile:: rmsnorm.h
@@ -11,6 +11,8 @@ pyTorch
 .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

+.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
+
 .. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
   :members: forward
@@ -461,16 +461,20 @@ def setup_common_extension() -> CMakeExtension:
         cmake_flags=cmake_flags,
     )

+def _all_files_in_dir(path):
+    return list(path.iterdir())
+

 def setup_pytorch_extension() -> setuptools.Extension:
     """Setup CUDA extension for PyTorch support"""

     # Source files
     src_dir = root_path / "transformer_engine" / "pytorch" / "csrc"
+    extensions_dir = src_dir / "extensions"
     sources = [
-        src_dir / "extensions.cu",
         src_dir / "common.cu",
         src_dir / "ts_fp8_op.cpp",
-    ]
+    ] + \
+    _all_files_in_dir(extensions_dir)

     # Header files
     include_dirs = [
@@ -21,7 +21,7 @@ from transformer_engine.pytorch.utils import (
     attention_mask_func,
 )
 from transformer_engine.pytorch import (
-    DotProductAttention, Linear, LayerNormLinear, LayerNormMLP, TransformerLayer
+    DotProductAttention, Linear, LayerNormLinear, LayerNormMLP, TransformerLayer, RMSNorm
 )
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint

@@ -59,6 +59,8 @@ all_boolean = [True, False]
 all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+all_normalizations = ["LayerNorm", "RMSNorm"]

 def get_causal_attn_mask(sq: int) -> torch.Tensor:
     return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()

@@ -74,7 +76,16 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float)
     """Ensures two lists are equal."""
     assert len(l1) == len(l2), "Unequal number of outputs."
     for t1, t2 in zip(l1, l2):
-        assert torch.allclose(t1, t2, atol=atol), "Outputs not close enough."
+        result = torch.allclose(t1, t2, atol=atol)
+        if not result:
+            diff = torch.abs(t1 - t2).flatten()
+            m = torch.argmax(diff)
+            msg = (f"Outputs not close enough. "
+                   f"Location of the maximum difference: {m.item()} "
+                   f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
+                   f"(diff {diff[m].item()}).")
+            raise AssertionError(msg)
def _set_cuda_rng_state(new_state, device=-1):

@@ -310,11 +321,38 @@ class TorchDotProductAttention(torch.nn.Module):
        return context_layer


# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    def __init__(self, in_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.weight = nn.Parameter(torch.ones(in_features))
        self.register_parameter("weight", self.weight)

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        return self.weight * x_normed
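In formula form, the reference TorchRMSNorm above computes

    \mathrm{RMS}(x) = \frac{\lVert x \rVert_2}{\sqrt{d}} = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2},
    \qquad
    y = \gamma \odot \frac{x}{\mathrm{RMS}(x) + \varepsilon},

where d is in_features and gamma is the learned weight. Note that eps is added to the RMS outside the square root, following the cited rmsnorm repository; this is the behavior the TE outputs are compared against.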
 class TorchLayerNormLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, eps: float, bias: bool = True):
+    def __init__(self, in_features: int, out_features: int,
+                 eps: float, bias: bool = True,
+                 normalization: str = "LayerNorm"):
         super().__init__()
-        self.layernorm = nn.LayerNorm(in_features, eps=eps)
+        if normalization == "LayerNorm":
+            self.layernorm = nn.LayerNorm(in_features, eps=eps)
+        elif normalization == "RMSNorm":
+            self.layernorm = TorchRMSNorm(in_features, eps=eps)
+        else:
+            raise RuntimeError("Unsupported normalization")
         self.linear = nn.Linear(in_features, out_features)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -355,9 +393,15 @@ class TorchGLU(nn.Module):
 class TorchLayerNormMLP(nn.Module):
     def __init__(self, hidden_size: int, ffn_hidden_size: int,
-                 eps: float = 1e-5, activation = 'gelu'):
+                 eps: float = 1e-5, activation = 'gelu',
+                 normalization: str = "LayerNorm"):
         super().__init__()
-        self.ln = nn.LayerNorm(hidden_size, eps=eps)
+        if normalization == "LayerNorm":
+            self.ln = nn.LayerNorm(hidden_size, eps=eps)
+        elif normalization == "RMSNorm":
+            self.ln = TorchRMSNorm(hidden_size, eps=eps)
+        else:
+            raise RuntimeError("Unsupported normalization")
         if 'glu' in activation:
             fc1_output_features = 2 * ffn_hidden_size
             self.gelu = TorchGLU(activation)
@@ -830,11 +874,48 @@ def test_linear_accuracy(dtype, bs, model):
     else:
         assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_rmsnorm_accuracy(dtype, bs, model):
    config = model_configs[model]

    te_rmsnorm = (
        RMSNorm(
            config.hidden_size,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )
    torch_rmsnorm = (
        TorchRMSNorm(
            config.hidden_size,
        )
        .to(dtype=dtype)
        .cuda()
        .eval()
    )

    # Share params
    with torch.no_grad():
        torch_rmsnorm.weight = Parameter(te_rmsnorm.weight.clone())

    te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
    torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)

    # Check output.
    if dtype == torch.float32:
        assert_allclose(te_outputs[0], torch_outputs[0], 1e-7)
    else:
        assert_allclose(te_outputs[0], torch_outputs[0], 2e-2)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
def test_layernorm_linear_accuracy(dtype, bs, model): @pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
config = model_configs[model] config = model_configs[model]
te_ln_linear = ( te_ln_linear = (
...@@ -843,6 +924,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model): ...@@ -843,6 +924,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model):
4 * config.hidden_size, 4 * config.hidden_size,
config.eps, config.eps,
bias=True, bias=True,
normalization=normalization,
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -855,6 +937,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model): ...@@ -855,6 +937,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model):
4 * config.hidden_size, 4 * config.hidden_size,
config.eps, config.eps,
bias=True, bias=True,
normalization=normalization,
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -864,7 +947,8 @@ def test_layernorm_linear_accuracy(dtype, bs, model): ...@@ -864,7 +947,8 @@ def test_layernorm_linear_accuracy(dtype, bs, model):
# Share params # Share params
with torch.no_grad(): with torch.no_grad():
torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone()) torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone()) if normalization != "RMSNorm":
torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone()) torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone()) torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
@@ -882,7 +966,8 @@ def test_layernorm_linear_accuracy(dtype, bs, model):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("activation", all_activations)
+@pytest.mark.parametrize("normalization", all_normalizations)
-def test_layernorm_mlp_accuracy(dtype, bs, model, activation):
+def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
     config = model_configs[model]

     te_ln_mlp = (

@@ -890,6 +975,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation):
             config.hidden_size,
             4 * config.hidden_size,
             activation=activation,
+            normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()

@@ -901,6 +987,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation):
             config.hidden_size,
             4 * config.hidden_size,
             activation=activation,
+            normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()

@@ -910,7 +997,8 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation):
     # Share params
     with torch.no_grad():
         torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
-        torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
+        if normalization != "RMSNorm":
+            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
         torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
         torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
         torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
@@ -71,6 +71,8 @@ skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+all_normalizations = ["LayerNorm", "RMSNorm"]

 @pytest.fixture()
 def seed_default_rng():

@@ -676,6 +678,90 @@ def test_export_layernorm(
     validate_result(
         fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize(
"use_fp8, precision, atol", [
[False, torch.float32, 1e-7],
[False, torch.float16, 1e-7],
[False, torch.bfloat16, 1e-7],
[False, "fake-torch.bfloat16", 1e-7],
[True, torch.float32, 1e-7],
[True, torch.float16, 1e-7],
[True, torch.bfloat16, 1e-2],
[True, "fake-torch.bfloat16", 1e-2]
])
def test_export_rmsnorm(
seed_default_rng,
use_fp8: bool,
scale_factor: float,
precision: torch.dtype,
atol: float
):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Set dimensions (these are arbitrary).
inp_shape = [64, 32]
class Test_RMSnorm(nn.Module):
def __init__(self) -> None:
super().__init__()
eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision
self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype).eval().cuda()
def forward(self, inp):
ret = self.ln(inp)
return ret
class TestFP8_RMSnorm(nn.Module):
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, device="cuda",
dtype=torch.float32 if fake_bf16_io else precision)
self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(scale_factor)
self.fp8_type = tex.DType.kFloat8E4M3
def forward(self, inp):
ret = texcpp.rmsnorm_fwd_fp8_inf(
inp,
self.weight,
self.eps,
self.meta,
self.fp8_tensor,
self.fp8_type,
False)
ret = cast_from_fp8(
ret,
self.meta,
self.fp8_tensor,
self.fp8_type,
as_te_type(precision))
if fake_bf16_io:
ret = ret.type(torch.float32)
return ret
inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision)
model = TestFP8_RMSnorm() if use_fp8 else Test_RMSnorm()
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fp8_str = f"_fp8-{scale_factor}" if use_fp8 else ""
fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"
do_export(model, inp, fname, use_fp8=use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
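validate_result presumably replays the exported graph against the serialized TE outputs; a hand-rolled equivalent with onnxruntime, sketched under the assumption that the exported model has a single input, would be:

import onnxruntime as ort

sess = ort.InferenceSession(fname, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name  # avoid hard-coding the tensor name
(ort_out,) = sess.run(None, {input_name: inp.cpu().numpy()})
# ort_out can then be compared against te_outputs within atol.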
@skip_FP8
@pytest.mark.parametrize("softmax_fn", [

@@ -916,6 +1002,7 @@ def test_export_linear(
     (torch.bfloat16, False),
 ])
 @pytest.mark.parametrize("zero_centered_gamma", [False, True])
+@pytest.mark.parametrize("normalization", all_normalizations)
 def test_export_layernorm_linear(
     seed_default_rng,
     scale_factor: float,

@@ -924,12 +1011,16 @@ def test_export_layernorm_linear(
     return_bias: bool,
     return_layernorm_output: bool,
     precision: torch.dtype,
-    zero_centered_gamma: bool
+    zero_centered_gamma: bool,
+    normalization: str,
 ):
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)

+    if normalization == "RMSNorm" and zero_centered_gamma:
+        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
+
     # Set dimensions (these are arbitrary).
     in_features = 64
     out_features = 256

@@ -950,6 +1041,7 @@ def test_export_layernorm_linear(
         return_layernorm_output=return_layernorm_output,
         params_dtype=precision,
         zero_centered_gamma=zero_centered_gamma,
+        normalization=normalization,
     ).to(device='cuda')
     if use_fp8:
         set_layer_scale(model, scale_factor, num_gemms=1)

@@ -980,6 +1072,7 @@ def test_export_layernorm_linear(
 ])
 @pytest.mark.parametrize("zero_centered_gamma", [False, True])
 @pytest.mark.parametrize("activation", supported_activations)
+@pytest.mark.parametrize("normalization", all_normalizations)
 def test_export_layernorm_mlp(
     seed_default_rng,
     scale_factor: float,

@@ -990,11 +1083,15 @@ def test_export_layernorm_mlp(
     precision: torch.dtype,
     zero_centered_gamma: bool,
     activation: str,
+    normalization: str,
 ):
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)

+    if normalization == "RMSNorm" and zero_centered_gamma:
+        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
+
     # Set dimensions (these are arbitrary).
     in_features = 64
     out_features = 256

@@ -1016,6 +1113,7 @@ def test_export_layernorm_mlp(
         params_dtype=precision,
         zero_centered_gamma=zero_centered_gamma,
         activation=activation,
+        normalization=normalization,
     ).to(device='cuda')
     if use_fp8:
         set_layer_scale(model, scale_factor, num_gemms=2)
@@ -95,6 +95,7 @@ batch_sizes = [1, 2]
 all_boolean = [True, False]
 all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+all_normalizations = ["LayerNorm", "RMSNorm"]

 def _disable_wgrads(block):
     for p in block.parameters():

@@ -314,10 +315,16 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_d
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
+@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad):
+def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad,
+                                 zero_centered_gamma, skip_dgrad,
+                                 normalization):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)

+    if normalization == "RMSNorm" and zero_centered_gamma:
+        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
+
     config = model_configs[model]

     sigma = 0.023

@@ -330,6 +337,7 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_
             eps=config.eps,
             init_method=init_method,
             zero_centered_gamma=zero_centered_gamma,
+            normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()

@@ -370,10 +378,16 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad, skip_dgrad):
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("skip_dgrad", all_boolean)
 @pytest.mark.parametrize("activation", all_activations)
+@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation):
+def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
+                              zero_centered_gamma, skip_dgrad, activation,
+                              normalization):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)

+    if normalization == "RMSNorm" and zero_centered_gamma:
+        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
+
     config = model_configs[model]

     sigma = 0.023

@@ -389,6 +403,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen
             output_layer_init_method=output_layer_init_method,
             zero_centered_gamma=zero_centered_gamma,
             activation=activation,
+            normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()

@@ -404,10 +419,16 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("activation", all_activations)
+@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, bias, activation):
+def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
+                    zero_centered_gamma, bias, activation,
+                    normalization):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)

+    if normalization == "RMSNorm" and zero_centered_gamma:
+        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
+
     config = model_configs[model]

     sigma = 0.023

@@ -430,6 +451,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm
             zero_centered_gamma=zero_centered_gamma,
             bias=bias,
             activation=activation,
+            normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()

@@ -444,10 +466,15 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
+                     normalization):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)

+    if normalization == "RMSNorm" and zero_centered_gamma:
+        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
+
     config = model_configs[model]

     sigma = 0.023

@@ -468,6 +495,7 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
             apply_residual_connection_post_layernorm=True,
             output_layernorm=True,
             zero_centered_gamma=zero_centered_gamma,
+            normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()

@@ -482,10 +510,15 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
+                   normalization):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)

+    if normalization == "RMSNorm" and zero_centered_gamma:
+        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
+
     config = model_configs[model]

     sigma = 0.023

@@ -507,6 +540,7 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma
             output_layernorm=False,
             layer_type="decoder",
             zero_centered_gamma=zero_centered_gamma,
+            normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()

@@ -669,10 +703,15 @@ def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+@pytest.mark.parametrize("normalization", all_normalizations)
-def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
+                        normalization):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)

+    if normalization == "RMSNorm" and zero_centered_gamma:
+        pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
+
     config = model_configs[model]

     sigma = 0.023

@@ -694,6 +733,7 @@ def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_
             output_layernorm=False,
             zero_centered_gamma=zero_centered_gamma,
             fuse_qkv_params=True,
+            normalization=normalization,
         )
         .to(dtype=dtype)
         .cuda()
@@ -7,6 +7,7 @@ from .module import LayerNormLinear
 from .module import Linear
 from .module import LayerNormMLP
 from .module import LayerNorm
+from .module import RMSNorm
 from .attention import DotProductAttention
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast

@@ -21,4 +22,6 @@ from .te_onnx_extensions import (
     onnx_te_gemm,
     onnx_layernorm_fwd_fp8,
     onnx_layernorm_fwd,
+    onnx_rmsnorm_fwd,
+    onnx_rmsnorm_fwd_fp8
 )
@@ -990,6 +990,7 @@ class MultiHeadAttention(torch.nn.Module):
             ub_split_rs: bool = False,
             ub_split_ag: bool = False,
             bias: bool = True,
+            normalization: str = "LayerNorm",
     ) -> None:
         super().__init__()
         self.layer_number = layer_number

@@ -1044,6 +1045,7 @@ class MultiHeadAttention(torch.nn.Module):
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
                 ub_split_ag=ub_split_ag,
+                normalization=normalization,
                 **common_gemm_kwargs,
             )
         else:

@@ -1072,6 +1074,7 @@ class MultiHeadAttention(torch.nn.Module):
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
                 ub_split_ag=ub_split_ag,
+                normalization=normalization,
                 **common_gemm_kwargs,
             )
         else:
@@ -10,7 +10,10 @@ import transformer_engine_extensions as tex
 __all__ = ['layernorm_fwd_fp8',
            'layernorm_fwd_fp8_inf',
-           'layernorm_fwd_inf']
+           'layernorm_fwd_inf',
+           'rmsnorm_fwd_fp8',
+           'rmsnorm_fwd_fp8_inf',
+           'rmsnorm_fwd_inf']

 def layernorm_fwd_fp8(

@@ -99,3 +102,83 @@ def layernorm_fwd_inf(
         eps,
         zero_centered_gamma,
     )
def rmsnorm_fwd_fp8(
    inp: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int,
    zero_centered_gamma: bool,
    rmsnorm_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """RMSNorm with FP8 output"""
    if rmsnorm_out is not None:
        return tex.rmsnorm_fwd_fp8_noalloc(
            inp,
            weight,
            eps,
            fp8_meta_tensor.scale[fp8_tensor],
            rmsnorm_out,
            fp8_meta_tensor.amax_history[0][fp8_tensor],
            fp8_meta_tensor.scale_inv[fp8_tensor],
            otype,
            sm_margin,
            zero_centered_gamma
        )
    return tex.rmsnorm_fwd_fp8(
        inp,
        weight,
        eps,
        fp8_meta_tensor.scale[fp8_tensor],
        fp8_meta_tensor.amax_history[0][fp8_tensor],
        fp8_meta_tensor.scale_inv[fp8_tensor],
        otype,
        sm_margin,
        zero_centered_gamma
    )


def rmsnorm_fwd_fp8_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    zero_centered_gamma,
) -> torch.Tensor:
    """RMSNorm with FP8 output.

    This version of rmsnorm_fwd_fp8 is specialized for inference, and returns
    only the normalized output.
    """
    ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts(
        inp,
        weight,
        eps,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
        zero_centered_gamma)
    return ret


def rmsnorm_fwd_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    zero_centered_gamma: bool,
) -> torch.Tensor:
    """RMSNorm without FP8 output, specialized for inference."""
    return torch.ops.tex_ts.rmsnorm_fwd_inf_ts(
        inp,
        weight,
        eps,
        zero_centered_gamma,
    )
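For a quick smoke test of the non-FP8 inference path, the wrapper can be called directly; a sketch, assuming texcpp refers to transformer_engine.pytorch.cpp_extensions as in the ONNX tests above:

import torch
import transformer_engine.pytorch.cpp_extensions as texcpp

x = torch.randn(32, 1024, device="cuda", dtype=torch.float16)
gamma = torch.ones(1024, device="cuda", dtype=torch.float16)

# zero_centered_gamma=False: gamma is applied as-is (the zero-centered
# variant is not yet supported for RMSNorm, per the test skips above).
y = texcpp.rmsnorm_fwd_inf(x, gamma, 1e-5, False)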
@@ -137,3 +137,11 @@ at::Tensor allocateTorchTensor(int M,
     return at::empty({static_cast<int64_t>(M)},
                      at::CUDA(GetATenDType(dtype)));
 }

void *getDataPtr(at::Tensor t) {
  if (t.numel() > 0) {
    return t.data_ptr();
  } else {
    return nullptr;
  }
}
@@ -9,6 +9,7 @@
 #include <transformer_engine/gemm.h>
 #include <transformer_engine/layer_norm.h>
+#include <transformer_engine/rmsnorm.h>
 #include <transformer_engine/transpose.h>
 #include <transformer_engine/activation.h>
 #include <transformer_engine/logging.h>

@@ -180,4 +181,6 @@ at::Tensor allocateTorchTensor(int M,
                                transformer_engine::DType dtype
 );

+void *getDataPtr(at::Tensor t);
+
 #endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
@@ -106,6 +106,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                 c10::optional<at::Tensor> amax_dP,
                 c10::optional<at::Tensor> amax_dQKV);

+at::Tensor fa_prepare_fwd(at::Tensor qkvi);
+at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
+
 void te_gemm(at::Tensor A,
              at::Tensor A_scale_inverse,
              transformer_engine::DType A_type,

@@ -318,6 +322,77 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                              const bool zero_centered_gamma
 );

/***************************************************************************************************
 * RMSNorm
 **************************************************************************************************/

std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz,
                                    const at::Tensor &x,
                                    const at::Tensor &rsigma,
                                    const at::Tensor &gamma,
                                    const int sm_margin,
                                    const bool zero_centered_gamma
);

std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input,
                                        const at::Tensor &weight,
                                        float eps,
                                        at::Tensor scale,
                                        at::Tensor amax,
                                        at::Tensor scale_inv,
                                        transformer_engine::DType otype,
                                        const int sm_margin,
                                        const bool zero_centered_gamma
);

std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
                                                const at::Tensor &weight,
                                                float eps,
                                                at::Tensor scale,
                                                at::Tensor ln_out,
                                                at::Tensor amax,
                                                at::Tensor scale_inv,
                                                transformer_engine::DType otype,
                                                const int sm_margin,
                                                const bool zero_centered_gamma
);

at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
                               const at::Tensor &weight,
                               float eps,
                               at::Tensor scale,
                               at::Tensor amax,
                               at::Tensor scale_inv,
                               transformer_engine::DType otype,
                               const bool zero_centered_gamma
);

std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input,
                                    const at::Tensor &weight,
                                    float eps,
                                    const int sm_margin,
                                    const bool zero_centered_gamma
);

std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
                                            const at::Tensor &weight,
                                            at::Tensor ln_out,
                                            float eps,
                                            const int sm_margin,
                                            const bool zero_centered_gamma
);

at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
                           const at::Tensor &weight,
                           float eps,
                           const bool zero_centered_gamma
);
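rmsnorm_bwd takes the saved input x together with rsigma and gamma, which suggests the training-path forward returns the normalized output along with per-row rsigma statistics. A hedged sketch of exercising the raw binding (the two-tensor return layout is an assumption drawn from these declarations, not stated in the diff):

import torch
import transformer_engine_extensions as tex

x = torch.randn(128, 4096, device="cuda")
gamma = torch.ones(4096, device="cuda")

# sm_margin=0 leaves all SMs available to the kernel; zero_centered_gamma=False.
out, rsigma = tex.rmsnorm_fwd(x, gamma, 1e-5, 0, False)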
/***************************************************************************************************
 * Cast
 **************************************************************************************************/

 at::Tensor cast_to_fp8(const at::Tensor &input,
                        const at::Tensor &scale,
                        at::Tensor amax,

@@ -374,3 +449,9 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                        at::Tensor softmax_results_,
                                                        float scale_factor
 );

+size_t get_cublasLt_version();
+
+bool userbuf_comm_available();
+
+void placeholder();
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
at::Tensor gelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor relu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = static_cast<size_t>(input.numel()) / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor drelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor geglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dgeglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor reglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_reglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dreglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor swiglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dswiglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
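All of the GLU-family kernels above (geglu, reglu, swiglu and their gradients) read N input columns but allocate only N / 2 output columns: the last dimension is split into an activation half and a gating half. A plain-PyTorch reference of that convention (the ordering of the two halves is an assumption, not fixed by this file):

import torch
import torch.nn.functional as F

def swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in two; activate one half, gate with the other.
    a, b = x.chunk(2, dim=-1)
    return F.silu(a) * b

y = swiglu_reference(torch.randn(8, 256))  # -> shape (8, 128)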
@@ -5,9 +5,6 @@
 ************************************************************************/

 #include "extensions.h"

-#ifdef NVTE_WITH_USERBUFFERS
-#include "comm_gemm_overlap.h"
-#endif  // NVTE_WITH_USERBUFFERS

 constexpr int block_size = 512;
 constexpr int ctas_per_sm = 4;

@@ -708,1261 +705,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
     return {dQ, dKV, dBias};
 }

(The large block below is the code removed by this hunk: te_gemm, the fused cast/transpose helpers, fp8_transpose, and the activation functions. Per the setup.py change above, they now compile from separate sources under the csrc/extensions directory rather than from this monolithic file.)
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count
) {
using namespace transformer_engine;
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
{static_cast<size_t>(A.size(0)),
static_cast<size_t>(A.size(1))},
A_type, nullptr, nullptr,
A_scale_inverse.data_ptr());
auto te_B = makeTransformerEngineTensor(B.data_ptr(),
{static_cast<size_t>(B.size(0)),
static_cast<size_t>(B.size(1))},
B_type, nullptr, nullptr,
B_scale_inverse.data_ptr());
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
D_type, D_amax.data_ptr(),
D_scale.data_ptr(), nullptr);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
bias_type);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(),
gelu_shape,
GetTransformerEngineDType(
pre_gelu_out.scalar_type()));
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
{workspaceSize},
DType::kByte);
nvte_cublas_gemm(te_A.data(),
te_B.data(),
te_D.data(),
te_bias.data(),
te_pre_gelu_out.data(),
transa,
transb,
grad,
te_workspace.data(),
accumulate,
use_split_accumulator,
math_sm_count,
at::cuda::getCurrentCUDAStream());
}
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto grad_output_cast =
allocateTorchTensor(grad_output.size(0),
grad_output.size(1),
DType::kByte);
auto grad_output_transpose =
allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(),
scale.data_ptr(), scale_inv.data_ptr());
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
return {grad_bias, grad_output_cast, grad_output_transpose};
}
std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
transformer_engine::DType grad_bias_type
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type);
auto grad_output_transpose =
allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(),
scale.data_ptr(), scale_inv.data_ptr());
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
return {grad_bias, grad_output_transpose};
}
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor gelu_input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto dgelu =
allocateTorchTensor(grad_output.size(0),
grad_output.size(1),
DType::kByte);
auto dgelu_transpose =
allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
transformer_engine::TensorWrapper workspace;
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input);
auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
return {grad_bias, dgelu, dgelu_transpose};
}
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
std::vector<at::Tensor> scale_list,
std::vector<at::Tensor> cast_output_list,
std::vector<at::Tensor> transposed_output_list,
std::vector<at::Tensor> amax_list,
std::vector<at::Tensor> scale_inv_list,
transformer_engine::DType otype
) {
using namespace transformer_engine;
// Extract properties from PyTorch tensors
std::vector<void*> input_dptr_list, scale_dptr_list,
cast_output_dptr_list, transposed_output_dptr_list,
amax_dptr_list, scale_inv_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, scale_shape_list,
cast_output_shape_list, transposed_output_shape_list,
amax_shape_list, scale_inv_shape_list;
std::vector<transformer_engine::DType> input_type_list, scale_type_list,
cast_output_type_list, transposed_output_type_list,
amax_type_list, scale_inv_type_list;
auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor,
std::vector<void*>& dptr_list,
std::vector<std::vector<size_t>>& shape_list) {
dptr_list.push_back(tensor.data_ptr());
shape_list.push_back({});
for (int d = 0; d < tensor.dim(); ++d) {
shape_list.back().push_back(tensor.size(d));
}
};
auto extract_tensor_props = [](at::Tensor& tensor,
std::vector<void*>& dptr_list,
std::vector<std::vector<size_t>>& shape_list,
std::vector<transformer_engine::DType>& type_list) {
dptr_list.push_back(tensor.data_ptr());
shape_list.push_back({});
for (int d = 0; d < tensor.dim(); ++d) {
shape_list.back().push_back(tensor.size(d));
}
type_list.push_back(GetTransformerEngineDType(tensor.scalar_type()));
};
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
extract_tensor_props(input_list[tensor_id],
input_dptr_list,
input_shape_list,
input_type_list);
extract_tensor_props(scale_list[tensor_id],
scale_dptr_list,
scale_shape_list,
scale_type_list);
extract_tensor_props_skip_dtype(cast_output_list[tensor_id],
cast_output_dptr_list,
cast_output_shape_list);
cast_output_type_list.push_back(otype);
extract_tensor_props_skip_dtype(transposed_output_list[tensor_id],
transposed_output_dptr_list,
transposed_output_shape_list);
transposed_output_type_list.push_back(otype);
extract_tensor_props(amax_list[tensor_id],
amax_dptr_list,
amax_shape_list,
amax_type_list);
extract_tensor_props(scale_inv_list[tensor_id],
scale_inv_dptr_list,
scale_inv_shape_list,
scale_inv_type_list);
}
transformer_engine::TensorWrapper workspace;
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list,
nvte_cast_output_list, nvte_transposed_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
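// make_tensor stores each TensorWrapper in tensor_wrappers so that the
// returned NVTETensor handles stay valid until the kernel launch below.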
auto make_tensor = [&tensor_wrappers](void* dptr,
const std::vector<size_t>& shape,
transformer_engine::DType dtype,
void* amax_dptr,
void* scale_dptr,
void* scale_inv_dptr)
-> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
nvte_input_list.emplace_back(make_tensor(input_dptr_list[i],
input_shape_list[i],
input_type_list[i],
nullptr,
nullptr,
nullptr));
nvte_cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
cast_output_shape_list[i],
cast_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
nvte_transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
transposed_output_shape_list[i],
transposed_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
}
// Check tensor lists
NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(),
"Number of input and T output tensors must match");
// Launch TE kernel
nvte_multi_cast_transpose(nvte_input_list.size(),
nvte_input_list.data(),
nvte_cast_output_list.data(),
nvte_transposed_output_list.data(),
at::cuda::getCurrentCUDAStream());
}
at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto output =
allocateTorchTensor(input.size(1),
input.size(0),
DType::kByte);
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor gelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
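// Flatten all leading dimensions into M rows of N features so the kernel
// operates on a 2D view of the input.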
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor relu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = static_cast<size_t>(input.numel()) / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor drelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor geglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
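// GLU-family activations split the last dimension into two halves (linear
// and gating), so the activated output has N / 2 features.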
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dgeglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
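// The incoming gradient has the activated shape {M, N / 2}; the computed
// gradient restores the packed input shape {M, N}.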
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor reglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_reglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dreglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor swiglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N / 2,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor dswiglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;
auto output =
allocateTorchTensor(M,
N,
otype);
auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &x,
const at::Tensor &mu,
const at::Tensor &rsigma,
const at::Tensor &gamma,
const int sm_margin,
const bool zero_centered_gamma
) {
auto dx = at::empty_like(x);
auto dgamma = at::empty_like(gamma);
auto dbeta = at::empty_like(gamma);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
auto dz_cu = makeTransformerEngineTensor(dz);
auto x_cu = makeTransformerEngineTensor(x);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
auto gamma_cu = makeTransformerEngineTensor(gamma);
auto dx_cu = makeTransformerEngineTensor(dx);
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
auto dbeta_cu = makeTransformerEngineTensor(dbeta);
// This first call populates the workspace, barrier, and dgamma/dbeta
// partial-reduction tensors with the shapes and dtypes they require;
// no computation happens yet.
const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true);
auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype());
auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(),
dgamma_part.shape(),
dgamma_part.dtype());
dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(),
dbeta_part.shape(),
dbeta_part.dtype());
// Actual call to bwd kernel.
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return { dx, dgamma, dbeta };
}
std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor ln_out,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const bool zero_centered_gamma
) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
return out[0];
}
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
DType itype = GetTransformerEngineDType(input.scalar_type());
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype)));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
at::Tensor ln_out,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
const bool zero_centered_gamma
) {
// This is a specialized version of layernorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma);
return out[0];
}
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
auto input_shape = input.sizes().vec();
std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_fp8_quantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return output;
}
void cast_to_fp8_noalloc(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor output,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_fp8_quantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
}
at::Tensor cast_from_fp8(const at::Tensor &input,
const at::Tensor &scale_inv,
transformer_engine::DType itype,
transformer_engine::DType otype
) {
using namespace transformer_engine;
auto input_shape = input.sizes().vec();
std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype,
nullptr, nullptr, scale_inv.data_ptr());
auto output_cu = makeTransformerEngineTensor(output);
nvte_fp8_dequantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return output;
}
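// Example (illustrative sketch only): round-tripping a CUDA tensor through
// FP8 with the two casts above. `x` and the fp32 meta tensors (`meta_scale`,
// `meta_amax`, `meta_scale_inv`) are hypothetical pre-allocated buffers.
//
//   at::Tensor fp8 = cast_to_fp8(x, meta_scale, meta_amax, meta_scale_inv,
//                                transformer_engine::DType::kFloat8E4M3);
//   at::Tensor y   = cast_from_fp8(fp8, meta_scale_inv,
//                                  transformer_engine::DType::kFloat8E4M3,
//                                  transformer_engine::DType::kFloat32);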
at::Tensor scaled_softmax_forward(at::Tensor input,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096, "key sequence length must not exceed 4096");
TORCH_CHECK(query_seq_len > 1, "query sequence length must be greater than 1");
// Output
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor,
at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_softmax_backward(
output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return output_grads;
}
at::Tensor scaled_masked_softmax_forward(at::Tensor input,
at::Tensor mask,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
if (!input.is_contiguous())
input = input.contiguous();
if (!mask.is_contiguous())
mask = mask.contiguous();
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096, "key sequence length must not exceed 4096");
TORCH_CHECK(query_seq_len > 1, "query sequence length must be greater than 1");
TORCH_CHECK(pad_batches == 1 || pad_batches == batches,
"mask batch dimension must be 1 or match the input batch dimension");
TORCH_CHECK(mask.size(1) == 1, "mask must have a single head dimension");
TORCH_CHECK(mask.size(2) == query_seq_len, "mask must match the query sequence length");
TORCH_CHECK(mask.size(3) == key_seq_len, "mask must match the key sequence length");
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto mask_cu = makeTransformerEngineTensor(mask);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_masked_softmax_forward(
input_cu.data(), mask_cu.data(), softmax_results_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_softmax_backward(
output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return output_grads;
}
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_CHECK(seq_len <= 2048, "sequence length must not exceed 2048");
// Output
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(),
softmax_results_cu.data(),
scale_factor,
at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
TORCH_CHECK(output_grads.size(1) == output_grads.size(2),
"output_grads must be square in its last two dimensions");
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_upper_triang_masked_softmax_backward(output_grads_cu.data(),
softmax_results_cu.data(),
output_grads_cu.data(),
scale_factor,
at::cuda::getCurrentCUDAStream());
return output_grads;
}
size_t get_cublasLt_version() {
return cublasLtGetVersion();
}
bool userbuf_comm_available() { // TODO(ksivamani) check on python side
#ifdef NVTE_WITH_USERBUFFERS
return true;
#else
return false;
#endif
}
void placeholder() {} // TODO(ksivamani) clean this up
namespace flash_attention {
constexpr int warp_size = 32;
...@@ -2132,146 +874,3 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
return qkv;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD");
m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
"Scaled Masked Softmax FWD");
m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
"Scaled Masked Softmax BWD");
m.def("scaled_upper_triang_masked_softmax_forward",
&scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD");
m.def("scaled_upper_triang_masked_softmax_backward",
&scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD");
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8");
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad,
"Fused FP8 Transpose + BGRAD");
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU");
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose");
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8");
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
m.def("te_gemm", &te_gemm, "CublasLt GEMM");
m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed QKV");
m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed QKV");
m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed KV");
m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed KV");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("gelu", &gelu, "GeLU with FP8 output");
m.def("relu", &relu, "ReLU with FP8 output");
m.def("geglu", &geglu, "GeGLU with FP8 output");
m.def("reglu", &reglu, "ReGLU with FP8 output");
m.def("swiglu", &swiglu, "SwiGLU with FP8 output");
m.def("dgelu", &dgelu, "Backward of GeLU");
m.def("drelu", &drelu, "Backward of ReLU");
m.def("dgeglu", &dgeglu, "Backward of GeGLU");
m.def("dreglu", &dreglu, "Backward of ReGLU");
m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>())
.def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale)
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
#ifdef NVTE_WITH_USERBUFFERS
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, bool, int>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else // NVTE_WITH_USERBUFFERS
m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations");
#endif // NVTE_WITH_USERBUFFERS
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
.value("kFloat32", transformer_engine::DType::kFloat32)
.value("kFloat16", transformer_engine::DType::kFloat16)
.value("kBFloat16", transformer_engine::DType::kBFloat16)
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2);
py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors")
.value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
.value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
.value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT)
.value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
.value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT)
.value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT)
.value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT)
.value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT)
.value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT);
py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors")
.value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
.value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2)
.value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2)
.value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3)
.value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3);
py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
.value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend")
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
auto input_shape = input.sizes().vec();
std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_fp8_quantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return output;
}
void cast_to_fp8_noalloc(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor output,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_fp8_quantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
}
at::Tensor cast_from_fp8(const at::Tensor &input,
const at::Tensor &scale_inv,
transformer_engine::DType itype,
transformer_engine::DType otype
) {
using namespace transformer_engine;
auto input_shape = input.sizes().vec();
std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype,
nullptr, nullptr, scale_inv.data_ptr());
auto output_cu = makeTransformerEngineTensor(output);
nvte_fp8_dequantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return output;
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
bool transa,
at::Tensor B,
at::Tensor B_scale_inverse,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count
) {
using namespace transformer_engine;
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
{static_cast<size_t>(A.size(0)),
static_cast<size_t>(A.size(1))},
A_type, nullptr, nullptr,
A_scale_inverse.data_ptr());
auto te_B = makeTransformerEngineTensor(B.data_ptr(),
{static_cast<size_t>(B.size(0)),
static_cast<size_t>(B.size(1))},
B_type, nullptr, nullptr,
B_scale_inverse.data_ptr());
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
D_type, D_amax.data_ptr(),
D_scale.data_ptr(), nullptr);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
bias_type);
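// An empty pre_gelu_out (null data pointer) means no auxiliary GELU output
// was requested; it then only needs a dummy 1D shape.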
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0)),
static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out.data_ptr(),
gelu_shape,
GetTransformerEngineDType(
pre_gelu_out.scalar_type()));
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
{workspaceSize},
DType::kByte);
nvte_cublas_gemm(te_A.data(),
te_B.data(),
te_D.data(),
te_bias.data(),
te_pre_gelu_out.data(),
transa,
transb,
grad,
te_workspace.data(),
accumulate,
use_split_accumulator,
math_sm_count,
at::cuda::getCurrentCUDAStream());
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#ifdef NVTE_WITH_USERBUFFERS
#include "comm_gemm_overlap.h"
#endif // NVTE_WITH_USERBUFFERS
size_t get_cublasLt_version() {
return cublasLtGetVersion();
}
bool userbuf_comm_available() { // TODO(ksivamani) check on python side
#ifdef NVTE_WITH_USERBUFFERS
return true;
#else
return false;
#endif
}
void placeholder() {} // TODO(ksivamani) clean this up
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &x,
const at::Tensor &mu,
const at::Tensor &rsigma,
const at::Tensor &gamma,
const int sm_margin,
const bool zero_centered_gamma
) {
auto dx = at::empty_like(x);
auto dgamma = at::empty_like(gamma);
auto dbeta = at::empty_like(gamma);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
auto dz_cu = makeTransformerEngineTensor(dz);
auto x_cu = makeTransformerEngineTensor(x);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
auto gamma_cu = makeTransformerEngineTensor(gamma);
auto dx_cu = makeTransformerEngineTensor(dx);
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
auto dbeta_cu = makeTransformerEngineTensor(dbeta);
// This first call populates the workspace, barrier, and dgamma/dbeta
// partial-reduction tensors with the shapes and dtypes they require;
// no computation happens yet.
const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true);
auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype());
auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(),
dgamma_part.shape(),
dgamma_part.dtype());
dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(),
dbeta_part.shape(),
dbeta_part.dtype());
// Actual call to bwd kernel.
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return { dx, dgamma, dbeta };
}
std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
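// Allocate the FP8 output here and delegate the actual computation to the
// _noalloc variant below.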
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
return layernorm_fwd_fp8_noalloc(input, weight, bias, eps,
scale, ln_out, amax, scale_inv,
otype, sm_margin, zero_centered_gamma);
}
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor ln_out,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
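// getDataPtr (unlike Tensor::data_ptr) is expected to yield nullptr for
// undefined tensors, which lets the non-FP8 entry point layernorm_fwd_noalloc
// reuse this function with empty scale/amax/scale_inv tensors.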
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
getDataPtr(amax), getDataPtr(scale),
getDataPtr(scale_inv));
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const bool zero_centered_gamma
) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
return out[0];
}
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
DType itype = GetTransformerEngineDType(input.scalar_type());
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype)));
return layernorm_fwd_noalloc(input, weight, bias, ln_out, eps,
sm_margin, zero_centered_gamma);
}
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
at::Tensor ln_out,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
DType itype = GetTransformerEngineDType(input.scalar_type());
return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(),
ln_out, at::Tensor(), at::Tensor(),
itype, sm_margin, zero_centered_gamma);
}
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
const bool zero_centered_gamma
) {
// This is a specialized version of layernorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma);
return out[0];
}
std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz,
const at::Tensor &x,
const at::Tensor &rsigma,
const at::Tensor &gamma,
const int sm_margin,
const bool zero_centered_gamma
) {
NVTE_CHECK(zero_centered_gamma == false,
"Zero-centered gamma is not supported yet for RMSNorm.");
auto dx = at::empty_like(x);
auto dgamma = at::empty_like(gamma);
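// RMSNorm does not center the input and has no beta term, so compared to
// layernorm_bwd there is no mu input and no dbeta output.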
transformer_engine::TensorWrapper workspace, barrier, dgamma_part;
auto dz_cu = makeTransformerEngineTensor(dz);
auto x_cu = makeTransformerEngineTensor(x);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
auto gamma_cu = makeTransformerEngineTensor(gamma);
auto dx_cu = makeTransformerEngineTensor(dx);
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
// This first call populates the workspace, barrier, and dgamma
// partial-reduction tensors with the shapes and dtypes they require;
// no computation happens yet.
const auto bwd_fun = nvte_rmsnorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dgamma_part.data(),
at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true);
auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(),
dgamma_part.shape(),
dgamma_part.dtype());
// Actual call to bwd kernel.
bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dgamma_part.data(),
at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return { dx, dgamma };
}
std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input,
const at::Tensor &weight,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
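// Allocate the FP8 output here and delegate to the _noalloc variant below.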
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
return rmsnorm_fwd_fp8_noalloc(input, weight, eps,
scale, ln_out, amax, scale_inv,
otype, sm_margin, zero_centered_gamma);
}
std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
const at::Tensor &weight,
float eps,
at::Tensor scale,
at::Tensor ln_out,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
NVTE_CHECK(zero_centered_gamma == false,
"Zero-centered gamma is not supported yet for RMSNorm.");
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
getDataPtr(amax), getDataPtr(scale),
getDataPtr(scale_inv));
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = nvte_rmsnorm_fwd;
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, rsigma};
}
at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
float eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const bool zero_centered_gamma
) {
// This is a specialized version of rmsnorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd_fp8(
input, weight, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
return out[0];
}
std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
DType itype = GetTransformerEngineDType(input.scalar_type());
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype)));
return rmsnorm_fwd_noalloc(input, weight, ln_out, eps,
sm_margin, zero_centered_gamma);
}
std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
const at::Tensor &weight,
at::Tensor ln_out,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
DType itype = GetTransformerEngineDType(input.scalar_type());
return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(),
ln_out, at::Tensor(), at::Tensor(),
itype, sm_margin, zero_centered_gamma);
}
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
float eps,
const bool zero_centered_gamma
) {
// This is a specialized version of rmsnorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, 0, zero_centered_gamma);
return out[0];
}
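
A minimal sketch of driving these bindings directly from Python. The extension module name (transformer_engine_extensions) is an assumption here; in normal use the bindings are wrapped by transformer_engine.pytorch.RMSNorm. Shapes and eps are illustrative:

import torch
import transformer_engine_extensions as tex  # assumed module name

x = torch.randn(32, 1024, device="cuda")
g = torch.ones(1024, device="cuda")

# Training path: returns the normalized output and per-row rsigma;
# sm_margin=0 leaves no SMs reserved for overlapping communication.
out, rsigma = tex.rmsnorm_fwd(x, g, 1e-5, 0, False)

# Inference path: returns only the normalized output.
out_inf = tex.rmsnorm_fwd_inf(x, g, 1e-5, False)
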
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#ifdef NVTE_WITH_USERBUFFERS
#include "comm_gemm_overlap.h"
#endif // NVTE_WITH_USERBUFFERS
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD");
m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
"Scaled Masked Softmax FWD");
m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
"Scaled Masked Softmax BWD");
m.def("scaled_upper_triang_masked_softmax_forward",
&scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD");
m.def("scaled_upper_triang_masked_softmax_backward",
&scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD");
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8");
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD");
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "LN FWD FP8");
m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "LN FWD FP8");
m.def("rmsnorm_bwd", &rmsnorm_bwd, "LN BWD");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "LN FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "LN FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad,
"Fused FP8 Transpose + BGRAD");
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU");
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose");
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8");
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
m.def("te_gemm", &te_gemm, "CublasLt GEMM");
m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed QKV");
m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed QKV");
m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked,
"Fused Attention FP8/BF16/FP16 FWD with packed KV");
m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
"Fused Attention FP8/BF16/FP16 BWD with packed KV");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("gelu", &gelu, "GeLU with FP8 output");
m.def("relu", &relu, "ReLU with FP8 output");
m.def("geglu", &geglu, "GeGLU with FP8 output");
m.def("reglu", &reglu, "ReGLU with FP8 output");
m.def("swiglu", &swiglu, "SwiGLU with FP8 output");
m.def("dgelu", &dgelu, "Backward of GeLU");
m.def("drelu", &drelu, "Backward of ReLU");
m.def("dgeglu", &dgeglu, "Backward of GeGLU");
m.def("dreglu", &dreglu, "Backward of ReGLU");
m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>())
.def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale)
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
#ifdef NVTE_WITH_USERBUFFERS
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, bool, int>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else // NVTE_WITH_USERBUFFERS
m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations");
#endif // NVTE_WITH_USERBUFFERS
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
.value("kFloat32", transformer_engine::DType::kFloat32)
.value("kFloat16", transformer_engine::DType::kFloat16)
.value("kBFloat16", transformer_engine::DType::kBFloat16)
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2);
py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors")
.value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
.value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
.value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT)
.value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
.value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT)
.value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT)
.value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT)
.value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT)
.value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT);
py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors")
.value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
.value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2)
.value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2)
.value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3)
.value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3);
py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
.value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
.value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
.value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend")
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
}
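
A hedged sketch of the FP8 path as seen from Python, combining the rmsnorm binding with the DType enum bound above. The module name and the single-slot shapes of the scaling tensors are assumptions for illustration:

import torch
import transformer_engine_extensions as tex  # assumed module name

x = torch.randn(32, 1024, device="cuda")
g = torch.ones(1024, device="cuda")
scale = torch.ones(1, device="cuda")      # quantization scale
amax = torch.zeros(1, device="cuda")      # abs-max, updated by the kernel
scale_inv = torch.ones(1, device="cuda")  # reciprocal scale for dequantization

# FP8 (E4M3) output; the signature mirrors rmsnorm_fwd plus the scaling tensors.
out, rsigma = tex.rmsnorm_fwd_fp8(x, g, 1e-5, scale, amax, scale_inv,
                                  tex.DType.kFloat8E4M3, 0, False)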