"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "31fc29abe6dedec2477f37a38945f8d9e57b1f14"
Unverified commit a3df1d73, authored by Tim Moon and committed by GitHub

[PyTorch] Prototype for operation-based API (#707)



* Add basic infrastructure for Sequential module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add FP8 support in linear op

Runs, but need to validate. Runtime errors with non-FP8 params and FP8 compute, or FP8 params and non-FP8 compute.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add reshape op and unit test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unfused linear op

Test does not pass with FP8.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug unfused linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add test for linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add separate abstract classes for unfused and fused ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Consolidate unfused ops in submodule
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add linear-bias fused op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use fused cast-transpose in linear ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable GEMM+bias fusion with FP32 activations

Not supported by cuBLAS.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add parallel unit test for unfused linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor parallel tests to reduce job launches
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add all-reduce, all-gather, and reduce-scatter ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused file
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug multi-GPU FP8 test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add support for FP8 scale updates

Still need to implement amax reductions.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add license boilerplate
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse GEMM+bias in row TP

Add documentation for unfused ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename pipeline to fuser

Expand documentation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak documentation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Preserve cached FP8 transpose between ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add option for fused wgrad accumulation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Directly output FP8 from linear if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix cuDNN front-end commit
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use updated FP8 tensor API for transpose caching
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use updated API for FP8 scale updates
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for non-default FP8 recipes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename UnfusedOperation to BasicOperation
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add unit test to check amax reduction with fusable op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Operator autograd state no longer needs to be initialized
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Initial functional implementation of linear op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug fused linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove autograd context from functional linear impl
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use functional linear impl in fused linear+bias op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename subdirectory from "fuser" to "ops"

Avoid confusion with kernel fusers and graph compilers.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update with Float8Tensor changes in #820
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unnecessary CPU overheads
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Correctly pass FP8 metadata from next op
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter errors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add convenience functions to manipulate Sequential class
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Clear saved tensor data in linear op after bprop
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix Pylint error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix test name in QA script
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update name of PyTorch extensions module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Run distributed tests even when only 1 GPU is available
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Only run distributed tests with 2 GPUs if there are >=2 GPUs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestions from @sudhakarsingh27 and @ksivaman

Fix spelling of "fusible". Avoid "input" name in internal APIs.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update transformer_engine/pytorch/ops/__init__.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 05977f44
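For context, the snippet below sketches how the operation-based API introduced by this PR can be used. It is pieced together from the tests in this diff (te_ops.Sequential, te_ops.BasicLinear, te_ops.Bias, and te.fp8_autocast all appear there); the shapes, dtypes, and device are illustrative only, not part of the PR.

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Compose a model from fusible operations. The Sequential container groups
# adjacent te_ops operations and may fuse them (e.g. GEMM + bias) at runtime.
model = te_ops.Sequential(
    te_ops.BasicLinear(1024, 1024, device="cuda", dtype=torch.bfloat16),
    te_ops.Bias(1024, device="cuda", dtype=torch.bfloat16),
)

# Forward and backward passes, optionally with FP8 compute.
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=True):
    y = model(x)
y.backward(torch.ones_like(y))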
@@ -22,3 +22,5 @@ pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import math
from collections.abc import Iterable
import pytest
import torch
import transformer_engine
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused_forward import (
ForwardLinearBiasActivation,
)
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
# Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
_dtypes.append(torch.bfloat16)
# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
The reference tensor is intended for use in plain PyTorch
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
"""
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
if test_is_fp8:
test = Float8Tensor.to_float8(ref)
test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1)
test._transpose = test._transpose.contiguous()
test._transpose_invalid = False
else:
test = ref.to(device=test_device, dtype=test_dtype)
if test.data_ptr() == ref.data_ptr():
test = test.clone()
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
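# Note on the helper above: when test_is_fp8 is set, the test tensor is quantized
# with Float8Tensor.to_float8 and its transpose cache is populated by hand,
# presumably so that downstream linear ops can reuse it without recomputing the
# transpose. The reference tensor is then overwritten with the quantized values
# (ref.copy_(test)), so reference and test tensors hold exactly the same numbers
# and any mismatch in the tests measures compute error only.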
class TestSequential:
"""Tests for sequential container"""
def test_modules(self) -> None:
"""Check that list of modules can be manipulated as expected"""
# Construct sequential container
modules = [
te_ops.Identity(),
te_ops.Identity(),
torch.nn.Identity(),
te_ops.Identity(),
]
model = te_ops.Sequential(*modules)
# Length
assert len(model) == len(modules)
# Iterator
for module1, module2 in zip(model, modules):
assert module1 is module2
# Index by int
for i, module in enumerate(modules):
assert model[i] is module
assert model[i - len(modules)] is module
# Index by slice
model_subset = model[1:-1]
modules_subset = modules[1:-1]
assert isinstance(model_subset, te_ops.Sequential)
for module1, module2 in zip(model_subset, modules_subset):
assert module1 is module2
# Set element
new_module = torch.nn.Identity()
idx = 1
modules[idx] = new_module
model[idx] = new_module
for module1, module2 in zip(model, modules):
assert module1 is module2
# Delete element
idx = 1
del modules[idx]
del model[idx]
for module1, module2 in zip(model, modules):
assert module1 is module2
# Append
new_module = torch.nn.Identity()
modules.append(new_module)
model.append(new_module)
for module1, module2 in zip(model, modules):
assert module1 is module2
# Extend
new_modules = [te_ops.Identity(), te_ops.Identity()]
modules.extend(new_modules)
model.extend(new_modules)
for module1, module2 in zip(model, modules):
assert module1 is module2
# Insert
new_module = te_ops.Identity()
idx = 2
modules.insert(idx, new_module)
model.insert(idx, new_module)
for module1, module2 in zip(model, modules):
assert module1 is module2
# Pop
idx = 2
assert model.pop(idx) is modules.pop(idx)
for module1, module2 in zip(model, modules):
assert module1 is module2
# Out-of-place add
new_modules = [torch.nn.Identity(), te_ops.Identity()]
added_modules = modules + new_modules
added_model = model + te_ops.Sequential(*new_modules)
for module1, module2 in zip(model, modules):
assert module1 is module2
for module1, module2 in zip(added_model, added_modules):
assert module1 is module2
# In-place add
new_modules = [te_ops.Identity(), torch.nn.Identity()]
modules += new_modules
model += te_ops.Sequential(*new_modules)
for module1, module2 in zip(model, modules):
assert module1 is module2
def test_module_groups(self) -> None:
"""Check that modules are grouped together correctly"""
model = te_ops.Sequential(
te_ops.Identity(),
te_ops.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
te_ops.Identity(),
torch.nn.Identity(),
te_ops.Identity(),
te_ops.Identity(),
te_ops.Identity(),
)
model(torch.zeros(1))
assert len(model._module_groups) == 6
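# Note: the expected count above reflects how Sequential groups modules for
# fusion: runs of consecutive te_ops operations are merged into one group,
# while each plain torch.nn module forms its own group, i.e.
# (te, te), (nn), (nn), (te), (nn), (te, te, te) -> 6 module groups.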
class TestFuser:
"""Tests for operation fusion infrastructure"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_scale_update(
self,
size: int = 16,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
"""Test FP8 scaling factors with delayed scaling recipe"""
# FP8 recipe
margin = 2
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=8,
amax_compute_algo="max",
)
# Construct model
with te.fp8_model_init():
model = te_ops.basic.BasicLinear(
size,
size,
device=device,
dtype=dtype,
)
# Training steps
w_vals = [2, 5, 3, 11]
x_vals = [7, 3, 5]
dy_vals = [1, 2, 1]
with torch.no_grad():
model.weight.fill_(w_vals[0])
for step in range(3):
# Data tensors
x = torch.full(
(size, size),
x_vals[step],
dtype=dtype,
device=device,
requires_grad=True,
)
dy = torch.full(
(size, size),
dy_vals[step],
dtype=dtype,
device=device,
)
# Training step
with te.fp8_autocast(fp8_recipe=recipe):
y = model(x)
y.backward(dy)
with torch.no_grad():
model.weight.fill_(w_vals[step + 1])
# Check that output tensors match expected
tols = dict(rtol=0, atol=0)
y_val_ref = w_vals[step] * x_vals[step] * size
dx_val_ref = w_vals[step] * dy_vals[step] * size
torch.testing.assert_close(
y,
torch.full_like(y, y_val_ref),
**dtype_tols(tex.DType.kFloat8E4M3),
)
torch.testing.assert_close(
x.grad,
torch.full_like(x.grad, dx_val_ref),
**dtype_tols(tex.DType.kFloat8E5M2),
)
# Check that scaling factors match expected
w_amax_ref = max(w_vals[: step + 2])
x_amax_ref = max(x_vals[: step + 1])
dy_amax_ref = max(dy_vals[: step + 1])
w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin)
dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin)
forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
w_scale = model.get_fp8_meta("param")[forward_key].scale
x_scale = model.get_fp8_meta("input")[forward_key].scale
dy_scale = model.get_fp8_meta("grad_output")[backward_key].scale
torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref))
torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref))
torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref))
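# Note: the reference scales above follow the delayed-scaling update rule,
# scale = fp8_max / (amax * 2**margin), with amax taken over the amax history
# ("max" algorithm). For the HYBRID format, fp8_max is 448 (E4M3) for the
# forward tensors and 57344 (E5M2) for the gradient.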
class TestBasicOps:
"""Tests for individual operations"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("in_shape", ((1,),))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("fp8", (False, True))
def test_identity(
self,
*,
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device,
fp8: bool,
) -> None:
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref
dx_ref = dy_ref
# Implementation with fusible operation
op = te_ops.Identity()
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
tols = dict(rtol=0, atol=0) # Identity is exact
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(y_test, -y_ref, **tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(dx_test, -dx_ref, **tols)
@pytest.mark.parametrize(
"shapes",
(
((1, 2, 3, 4), (2, 12)),
((5, 4, 3, 2), (-1, 6)),
((30,), (2, 3, -1)),
((6, 7), (3, -1, 7)),
),
)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize(
"memory_format",
(torch.contiguous_format, torch.channels_last),
)
@pytest.mark.parametrize("fp8", (False, True))
def test_reshape(
self,
*,
shapes: tuple[Iterable[int], Iterable[int]],
dtype: torch.dtype,
device: torch.device,
memory_format: torch.memory_format,
fp8: bool,
) -> None:
in_shape, out_shape = shapes
# Skip invalid configurations
if memory_format == torch.channels_last and len(in_shape) != 4:
pytest.skip("torch.channels_last only supports 4D tensors")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
x_test = x_test.contiguous(memory_format=memory_format)
x_test = x_test.detach().requires_grad_()
dy_ref, dy_test = make_reference_and_test_tensors(
x_ref.reshape(out_shape).size(),
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref.reshape(out_shape)
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.Reshape(out_shape)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
tols = dict(rtol=0, atol=0) # Reshape is exact
y_test = y_test.to(
dtype=torch.float64,
device="cpu",
memory_format=torch.contiguous_format,
)
dx_test = x_test.grad.to(
dtype=torch.float64,
device="cpu",
memory_format=torch.contiguous_format,
)
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("size", (1, 7, 32))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("fp8", (False, True))
def test_bias(
self,
*,
size: int,
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device,
fp8: bool,
) -> None:
# Make input and bias shapes consistent
in_shape = list(in_shape)[:-1] + [size]
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
b_ref, b_test = make_reference_and_test_tensors(
size,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [size])
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.Bias(size, device=device, dtype=dtype)
with torch.no_grad():
op.bias.copy_(b_test)
del b_test
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_compute", (False, True))
@pytest.mark.parametrize("fp8_input", (False, True))
@pytest.mark.parametrize("fp8_weight", (False, True))
@pytest.mark.parametrize("fp8_grad_output", (False, True))
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
def test_basic_linear(
self,
*,
weight_shape: tuple[int, int],
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_compute: bool,
fp8_input: bool,
fp8_weight: bool,
fp8_grad_output: bool,
accumulate_into_main_grad: bool,
) -> None:
"""GEMM"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
if fp8_compute or fp8_input or fp8_weight or fp8_grad_output:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
if fp8_compute:
if (
math.prod(in_shape[:-1]) % 16 != 0
or in_features % 16 != 0
or out_features % 16 != 0
):
pytest.skip("FP8 GEMMs require dims that are divisible by 16")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_input),
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_weight),
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_grad_output),
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_weight):
op = te_ops.BasicLinear(
in_features,
out_features,
device=device,
dtype=dtype,
accumulate_into_main_grad=accumulate_into_main_grad,
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
with te.fp8_autocast(enabled=fp8_compute):
y_test = op(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3
)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
if accumulate_into_main_grad:
if op.weight.grad is not None:
torch.testing.assert_close(
op.weight.grad,
torch.zeros_like(op.weight.grad),
rtol=0,
atol=0,
)
dw_test = op.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5
else:
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(
op.weight.main_grad,
torch.full_like(op.weight.main_grad, 0.5),
rtol=0,
atol=0,
)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
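# Note: when accumulate_into_main_grad is set, test_basic_linear above expects
# the weight gradient to be accumulated directly into weight.main_grad (which
# was pre-filled with 0.5, hence the "- 0.5" when recovering dw) while
# weight.grad, if it exists at all, stays zero. Otherwise main_grad must be
# left untouched and the gradient appears in weight.grad as usual.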
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("fp8_compute", (False, True))
@pytest.mark.parametrize("fp8_weight", (False, True))
def test_linear(
self,
*,
bias: bool,
weight_shape: tuple[int, int] = (16, 16),
in_shape: Iterable[int] = (16, -1),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8_compute: bool,
fp8_input: bool = False,
fp8_weight: bool,
) -> None:
"""GEMM + bias"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
if fp8_input or fp8_weight or fp8_compute:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
if fp8_compute:
if (
math.prod(in_shape[:-1]) % 16 != 0
or in_features % 16 != 0
or out_features % 16 != 0
):
pytest.skip("FP8 GEMMs require dims that are divisible by 16")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_input),
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_weight),
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_weight):
op = te_ops.Linear(
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
with torch.no_grad():
op.weight.copy_(w_test)
if bias:
op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
y_test = op(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3
)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
class TestFusedOps:
"""Tests for fused operations"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("weight_shape", ((32, 48), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (4, 2, 10, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_compute", (False, True))
@pytest.mark.parametrize("fp8_input", (False, True))
@pytest.mark.parametrize("fp8_weight", (False, True))
def test_linear_bias_activation(
self,
*,
bias: bool = True,
weight_shape: tuple[int, int],
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_compute: bool,
fp8_input: bool,
fp8_weight: bool,
) -> None:
"""GEMM + bias + activation"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
if fp8_input or fp8_weight or fp8_compute:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
if fp8_compute:
if (
math.prod(in_shape[:-1]) % 16 != 0
or in_features % 16 != 0
or out_features % 16 != 0
):
pytest.skip("FP8 GEMMs require dims that are divisible by 16")
if dtype not in (torch.float16, torch.bfloat16):
pytest.skip(
"FP8 fused linear-bias-activation is only supported with FP16 or BF16 output"
)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_input),
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_weight),
)
b_ref, b_test = None, None
if bias:
b_ref, b_test = make_reference_and_test_tensors(
out_features,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operations
with te.fp8_model_init(enabled=fp8_weight):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
if bias:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
y_test = model(x_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(forward_ops[0][0], ForwardLinearBiasActivation)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
else tex.DType.kFloat8E4M3
)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_linear(
self,
*,
in_shape: Iterable[int] = (16, 16),
dtype: torch.dtype = torch.bfloat16,
device: torch.device = "cuda",
) -> None:
"""Adjacent linear ops with FP8 enabled"""
# Make input and weight shapes consistent
in_shape = tuple(in_shape)
weight_shape = (in_shape[-1], in_shape[-1])
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=True,
)
w0_ref, w0_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=True,
)
w1_ref, w1_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=True,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w0_ref)
y_ref = torch.nn.functional.linear(y_ref, w1_ref)
y_ref.backward(dy_ref)
# Implementation with fusible operations
with te.fp8_model_init(enabled=True):
model = te_ops.Sequential(
te_ops.BasicLinear(
in_shape[-1],
in_shape[-1],
device=device,
dtype=dtype,
),
te_ops.BasicLinear(
in_shape[-1],
in_shape[-1],
device=device,
dtype=dtype,
),
)
with torch.no_grad():
model[0].weight.copy_(w0_test)
model[1].weight.copy_(w1_test)
del w0_test, w1_test
with te.fp8_autocast(enabled=True):
y_test = model(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(model[0].weight._fp8_dtype)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw0_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw0_test, w0_ref.grad, **tols)
torch.testing.assert_close(dw1_test, w1_ref.grad, **tols)
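# Note: test_fp8_linear above chains two BasicLinear ops with FP8 parameters and
# FP8 compute, which appears to exercise passing FP8 data and metadata directly
# between adjacent ops (see "Directly output FP8 from linear if needed" in the
# commit message); tolerances are therefore taken from the weight's FP8 dtype.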
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import argparse
from collections.abc import Iterable
import functools
import itertools
import os
import pathlib
import subprocess
import sys
import pytest
import torch
import transformer_engine
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@functools.cache
def world_group() -> torch.distributed.ProcessGroup:
"""Get NCCL process group, initializing if needed"""
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(rank)
group = torch.distributed.init_process_group(
"nccl",
init_method="file:///tmp/rdzv",
world_size=world_size,
rank=rank,
)
return group
def reset_rng(seed: int = 1234) -> None:
"""Reset random number generators"""
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
The reference tensor is intended for use in plain PyTorch
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
"""
# Random data
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
if test_is_fp8:
test = Float8Tensor.to_float8(ref)
else:
test = ref.to(device=test_device, dtype=test_dtype)
if test.data_ptr() == ref.data_ptr():
test = test.clone()
# Make sure reference and test tensors represent exact same values
ref.copy_(test)
# Return reference and test tensors
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
def _test_all_reduce(
*,
local_size: int = 17,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [local_size]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
# Plain PyTorch implementation
y_ref = x_ref.sum(0)
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dx_ref = x_ref.grad[rank]
x_ref = x_ref[rank]
x_test = x_test[rank].clone()
x_test.requires_grad_()
# Implementation with fusible operation
op = te_ops.AllReduce(process_group=process_group)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype))
torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0)
def _test_all_gather(
*,
local_size: int = 13,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [world_size, world_size * local_size]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
# Plain PyTorch implementation
y_ref = x_ref.tile((world_size, 1)).reshape(out_shape)
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dx_ref = x_ref.grad[rank]
x_ref = x_ref[rank]
x_test = x_test[rank].clone()
y_ref = y_ref[rank]
dy_ref = dy_ref[rank]
dy_test = dy_test[rank].clone()
x_test.requires_grad_()
# Implementation with fusible operation
op = te_ops.AllGather(process_group=process_group)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, dx_ref, **dtype_tols(dtype))
def _test_reduce_scatter(
*,
local_size: int = 11,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, world_size * local_size]
out_shape = [world_size, local_size]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
# Plain PyTorch implementation
y_ref = x_ref.sum(0).reshape(out_shape)
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dx_ref = x_ref.grad[rank]
x_ref = x_ref[rank]
x_test = x_test[rank].clone()
y_ref = y_ref[rank]
dy_ref = dy_ref[rank]
dy_test = dy_test[rank].clone()
x_test.requires_grad_()
# Implementation with fusible operation
op = te_ops.ReduceScatter(process_group=process_group)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype))
torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0)
def _test_basic_linear(
*,
local_weight_shape: tuple[int, int] = (16, 16),
batch_size: int = 16,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8_compute: bool = False,
fp8_input: bool = False,
fp8_weight: bool = False,
fp8_grad_output: bool = False,
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
local_out_features, local_in_features = local_weight_shape
out_features, in_features = local_out_features, local_in_features
if tensor_parallel_mode == "column":
out_features *= world_size
elif tensor_parallel_mode == "row":
in_features *= world_size
in_shape = [batch_size, in_features]
out_shape = [batch_size, out_features]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_input),
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_weight),
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_grad_output),
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dw_ref = w_ref.grad
dx_ref = x_ref.grad
if tensor_parallel_mode == "column":
local_out_features = out_features // world_size
local_slice = slice(
rank * local_out_features,
(rank + 1) * local_out_features,
)
w_ref = w_ref[local_slice, :]
dw_ref = dw_ref[local_slice, :]
w_test = w_test[local_slice, :]
y_ref = y_ref[..., local_slice]
dy_ref = dy_ref[..., local_slice]
dy_test = dy_test[..., local_slice].clone()
elif tensor_parallel_mode == "row":
local_in_features = in_features // world_size
local_slice = slice(
rank * local_in_features,
(rank + 1) * local_in_features,
)
w_ref = w_ref[:, local_slice]
dw_ref = dw_ref[:, local_slice]
w_test = w_test[:, local_slice]
x_ref = x_ref[..., local_slice]
dx_ref = dx_ref[..., local_slice]
x_test = x_test[..., local_slice].clone()
if sequence_parallel:
local_batch_size = batch_size // world_size
local_slice = slice(
rank * local_batch_size,
(rank + 1) * local_batch_size,
)
if tensor_parallel_mode == "column":
x_ref = x_ref[local_slice, ...]
dx_ref = dx_ref[local_slice, ...]
x_test = x_test[local_slice, ...].clone()
elif tensor_parallel_mode == "row":
y_ref = y_ref[local_slice, ...]
dy_ref = dy_ref[local_slice, ...]
dy_test = dy_test[local_slice, ...].clone()
x_test.requires_grad_()
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_weight):
op = te_ops.BasicLinear(
in_features,
out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=fp8_compute):
y_test = op(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3
)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
torch.testing.assert_close(dw_test, dw_ref, **tols)
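# Note on the parallel layouts above: "column" tensor parallelism shards the
# output features, so the weight rows and y/dy are sliced per rank, while "row"
# parallelism shards the input features, so the weight columns and x/dx are
# sliced. With sequence parallelism the batch dimension of the activation that
# is not feature-sharded is additionally split across ranks.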
def _test_linear(
*,
bias: bool = True,
local_weight_shape: tuple[int, int] = (16, 16),
batch_size: int = 16,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8_compute: bool = False,
fp8_input: bool = False,
fp8_weight: bool = False,
fp8_grad_output: bool = False,
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
local_out_features, local_in_features = local_weight_shape
out_features, in_features = local_out_features, local_in_features
if tensor_parallel_mode == "column":
out_features *= world_size
elif tensor_parallel_mode == "row":
in_features *= world_size
in_shape = [batch_size, in_features]
out_shape = [batch_size, out_features]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_input),
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_weight),
)
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
bias_shape = [world_size, out_features]
else:
bias_shape = [out_features]
b_ref, b_test = make_reference_and_test_tensors(
bias_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=(fp8_compute or fp8_grad_output),
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
if bias:
if tensor_parallel_mode == "row":
y_ref += b_ref.sum(dim=0)
else:
y_ref += b_ref
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dw_ref = w_ref.grad
db_ref = b_ref.grad if bias else None
dx_ref = x_ref.grad
if tensor_parallel_mode == "column":
local_out_features = out_features // world_size
local_slice = slice(
rank * local_out_features,
(rank + 1) * local_out_features,
)
w_ref = w_ref[local_slice, :]
dw_ref = dw_ref[local_slice, :]
w_test = w_test[local_slice, :]
if bias:
b_ref = b_ref[local_slice]
db_ref = db_ref[local_slice]
b_test = b_test[local_slice]
y_ref = y_ref[..., local_slice]
dy_ref = dy_ref[..., local_slice]
dy_test = dy_test[..., local_slice].clone()
elif tensor_parallel_mode == "row":
local_in_features = in_features // world_size
local_slice = slice(
rank * local_in_features,
(rank + 1) * local_in_features,
)
w_ref = w_ref[:, local_slice]
dw_ref = dw_ref[:, local_slice]
w_test = w_test[:, local_slice]
if bias:
b_ref = b_ref[rank, :]
db_ref = db_ref[rank, :]
b_test = b_test[rank, :]
x_ref = x_ref[..., local_slice]
dx_ref = dx_ref[..., local_slice]
x_test = x_test[..., local_slice].clone()
if sequence_parallel:
local_batch_size = batch_size // world_size
local_slice = slice(
rank * local_batch_size,
(rank + 1) * local_batch_size,
)
if tensor_parallel_mode == "column":
x_ref = x_ref[local_slice, ...]
dx_ref = dx_ref[local_slice, ...]
x_test = x_test[local_slice, ...].clone()
elif tensor_parallel_mode == "row":
y_ref = y_ref[local_slice, ...]
dy_ref = dy_ref[local_slice, ...]
dy_test = dy_test[local_slice, ...].clone()
x_test.requires_grad_()
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_weight):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
if bias:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
y_test = model(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
else tex.DType.kFloat8E4M3
)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
torch.testing.assert_close(dw_test, dw_ref, **tols)
if bias:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, db_ref, **tols)
def _test_fp8_scale_update(
*,
amax_history_len: int = 31,
amax_compute_algo: str = "max",
margin: float = 2,
local_weight_shape: tuple[int, int] = (16, 16),
batch_size: int = 16,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
tensor_parallel_mode: str = "column",
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
local_out_features, local_in_features = local_weight_shape
out_features, in_features = local_out_features, local_in_features
if tensor_parallel_mode == "column":
out_features *= world_size
elif tensor_parallel_mode == "row":
in_features *= world_size
in_shape = [batch_size, in_features]
out_shape = [batch_size, out_features]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
def ref_amax_and_scale(
ref: torch.Tensor,
stage: str,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Expected absmax and FP8 scale"""
amax = ref.abs().amax()
max_val = {
"forward": 448.0,
"backward": 57344.0,
}[stage]
scale = (max_val / amax) / (2**margin)
amax = amax.to(dtype=torch.float32, device="cpu")
scale = scale.to(dtype=torch.float32, device="cpu")
return amax, scale
# Compute expected amaxes and FP8 scales
x_amax_ref, x_scale_ref = ref_amax_and_scale(x_ref, "forward")
w_amax_ref, w_scale_ref = ref_amax_and_scale(w_ref, "forward")
dy_amax_ref, dy_scale_ref = ref_amax_and_scale(dy_ref, "backward")
# Convert to distributed tensors
with torch.no_grad():
if tensor_parallel_mode == "column":
local_out_features = out_features // world_size
local_slice = slice(
rank * local_out_features,
(rank + 1) * local_out_features,
)
w_ref = w_ref[local_slice, :]
w_test = w_test[local_slice, :]
dy_ref = dy_ref[..., local_slice]
dy_test = dy_test[..., local_slice].clone()
elif tensor_parallel_mode == "row":
local_in_features = in_features // world_size
local_slice = slice(
rank * local_in_features,
(rank + 1) * local_in_features,
)
w_ref = w_ref[:, local_slice]
w_test = w_test[:, local_slice]
x_ref = x_ref[..., local_slice]
x_test = x_test[..., local_slice].clone()
x_test.requires_grad_()
# Initialize fusible operation
op = te_ops.BasicLinear(
in_features,
out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=process_group,
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
# Forward and backward pass
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
with te.fp8_autocast(fp8_recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
x_fp8_meta = op.get_fp8_meta("input")[forward_key]
w_fp8_meta = op.get_fp8_meta("param")[forward_key]
dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key]
x_amax_test = x_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu")
w_amax_test = w_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu")
dy_amax_test = dy_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu")
x_scale_test = x_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu")
w_scale_test = w_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu")
dy_scale_test = dy_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu")
torch.testing.assert_close(x_amax_test, x_amax_ref)
torch.testing.assert_close(w_amax_test, w_amax_ref)
torch.testing.assert_close(dy_amax_test, dy_amax_ref)
torch.testing.assert_close(x_scale_test, x_scale_ref)
torch.testing.assert_close(w_scale_test, w_scale_ref)
torch.testing.assert_close(dy_scale_test, dy_scale_ref)
def run_parallel_tests() -> None:
"""Run parallel tests"""
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Collective communication ops
if rank == 0:
print(f"Running _test_all_reduce")
_test_all_reduce()
if rank == 0:
print(f"Running _test_all_gather")
_test_all_gather()
if rank == 0:
print(f"Running _test_reduce_scatter")
_test_reduce_scatter()
# Basic linear op
for config in itertools.product(
(False, True) if fp8_available else (False,),
("column", "row"),
(False, True),
):
if rank == 0:
print(f"Running _test_basic_linear with {config=}")
fp8, tensor_parallel_mode, sequence_parallel = config
_test_basic_linear(
fp8_compute=fp8,
fp8_input=fp8,
fp8_weight=fp8,
fp8_grad_output=fp8,
tensor_parallel_mode=tensor_parallel_mode,
sequence_parallel=sequence_parallel,
)
# Linear op
for config in itertools.product(
(False, True) if fp8_available else (False,),
("column", "row"),
):
if rank == 0:
print(f"Running _test_linear with {config=}")
fp8, tensor_parallel_mode = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
_test_linear(
bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype,
fp8_compute=fp8,
fp8_input=fp8,
fp8_weight=fp8,
fp8_grad_output=fp8,
tensor_parallel_mode=tensor_parallel_mode,
)
# FP8 scale update
if fp8_available:
if rank == 0:
print(f"Running _test_fp8_scale_update")
_test_fp8_scale_update()
# Parallel job sizes
_world_sizes = [torch.cuda.device_count()]
if 1 not in _world_sizes:
_world_sizes.append(1)
if torch.cuda.device_count() >= 2 and 2 not in _world_sizes:
_world_sizes.append(2)
@pytest.mark.parametrize("world_size", _world_sizes)
def test_distributed_fuser_ops(world_size: int) -> None:
"""Launch parallel job that runs parallel tests"""
python_exe = pathlib.Path(sys.executable).resolve()
current_file = pathlib.Path(__file__).resolve()
command = [
python_exe,
"-m",
"torch.distributed.run",
f"--nproc_per_node={world_size}",
current_file,
"--parallel",
]
result = subprocess.run(
command,
check=True,
)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
args = parser.parse_args()
if args.parallel:
run_parallel_tests()
if __name__ == "__main__":
main()
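# Note: test_distributed_fuser_ops above launches this file under
# torch.distributed.run. Assuming a node with 2 GPUs, a roughly equivalent
# manual invocation would be:
#   python -m torch.distributed.run --nproc_per_node=2 \
#       tests/pytorch/test_fusible_ops_distributed.py --parallel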
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
-from typing import Optional
+from typing import Iterable, Optional
import pytest
import torch
@@ -15,6 +15,8 @@ from transformer_engine.pytorch.fp8 import (
_amax_and_scale_update,
get_default_fp8_recipe,
)
+import transformer_engine.pytorch.ops as te_ops
+import transformer_engine_torch as tex
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -33,7 +35,7 @@ class TestFP8Recipe:
@pytest.mark.parametrize("amax_history_len", [31, 1024])
@pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
@pytest.mark.parametrize("is_first_microbatch", [None, True, False])
-def test_amax_and_scale_update(
+def test_fp8_scale_update_with_linear_module(
self,
amax_history_len: int,
amax_compute_algo: str,
@@ -49,7 +51,7 @@ class TestFP8Recipe:
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
-with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+with te.fp8_autocast(fp8_recipe=recipe):
module = te.Linear(16, 16)
y = module(
torch.randn([16, 16], device="cuda"),
@@ -162,6 +164,130 @@ class TestFP8Recipe:
ref_scale_inv_backward[0],
)
@pytest.mark.parametrize("amax_history_len", [31, 1024])
@pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
def test_fp8_scale_update_with_linear_fuser_op(
self,
amax_history_len: int,
amax_compute_algo: str,
margin: float = 2,
num_steps: int = 4,
in_shape: tuple[int] = (16, 16),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
):
# Construct linear op
op = te_ops.BasicLinear(in_shape[-1], in_shape[-1])
# Get FP8 meta tensors
forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
x_fp8_meta = op.get_fp8_meta("input")[forward_key]
w_fp8_meta = op.get_fp8_meta("param")[forward_key]
dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key]
# Perform training steps
x_history = []
w_history = []
dy_history = []
for step in range(num_steps):
# Fill tensors with known values
x_history.append(step + 0.25)
w_history.append(step + 0.5)
dy_history.append(step + 0.75)
x = torch.full(
in_shape,
x_history[-1],
dtype=dtype,
device=device,
requires_grad=True,
)
dy = torch.full(
in_shape,
dy_history[-1],
dtype=dtype,
device=device,
)
with torch.no_grad():
op.weight.fill_(w_history[-1])
# Forward and backward pass
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
with te.fp8_autocast(fp8_recipe=recipe):
y = op(x)
y.backward(dy)
def check_amax_history(
fp8_meta: dict,
ref_amax_history: Iterable[float],
) -> None:
"""Check that amax history matches expected values"""
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-amax_history_len:]
ref_amax_history = torch.tensor(
ref_amax_history,
dtype=torch.float32,
device=device,
)
test_amax_history = fp8_meta.amax_history[:, 0]
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(
test_amax_history[-(step + 1) :],
ref_amax_history[: (step + 1)],
**tols,
)
def check_scale(
fp8_meta: dict,
ref_amax_history: Iterable[float],
stage: str,
):
"""Check that scale and scale reciprocal match expected values"""
# Compute amax
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
if amax_compute_algo == "max":
ref_amax = max(ref_amax_history)
elif amax_compute_algo == "most_recent":
ref_amax = ref_amax_history[-1]
else:
raise RuntimeError(f"{amax_compute_algo=} is not supported")
# Compute scale
max_val = {
"forward": 448.0,
"backward": 57344.0,
}[stage]
ref_scale = (max_val / ref_amax) / (2**margin)
# Check values in FP8 meta tensors
torch.testing.assert_close(
fp8_meta.scale.item(),
ref_scale,
)
torch.testing.assert_close(
fp8_meta.scale_inv.item(),
1 / ref_scale,
)
# Check that results match expected values
check_amax_history(x_fp8_meta, x_history)
check_amax_history(w_fp8_meta, w_history)
check_amax_history(dy_fp8_meta, dy_history)
check_scale(x_fp8_meta, x_history, "forward")
check_scale(w_fp8_meta, w_history, "forward")
check_scale(dy_fp8_meta, dy_history, "backward")
@pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"]) @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
@pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"]) @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -191,7 +317,7 @@ class TestFP8Recipe: ...@@ -191,7 +317,7 @@ class TestFP8Recipe:
# Setup fp8_meta dictionary # Setup fp8_meta dictionary
def setup_fp8_meta(): def setup_fp8_meta():
with te.fp8_autocast(enabled=True, fp8_recipe=recipe): with te.fp8_autocast(fp8_recipe=recipe):
module = te.Linear(16, 16) module = te.Linear(16, 16)
y = module(torch.zeros([16, 16], device="cuda")) y = module(torch.zeros([16, 16], device="cuda"))
y.backward(torch.zeros_like(y)) y.backward(torch.zeros_like(y))
......
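For reference, the scale update verified by the new check_scale helper above follows the delayed-scaling recipe directly. A small worked example (plain Python, not part of the commit; values are arbitrary):

# Worked example of the delayed-scaling update exercised by check_scale.
# With margin=2 and amax_compute_algo="max", an amax history of [0.25, 1.25, 2.25]
# yields amax = 2.25. Forward tensors use the E4M3 format, whose representable
# maximum is 448, so the expected scale and its reciprocal are:
margin = 2
amax_history = [0.25, 1.25, 2.25]
amax = max(amax_history)
scale = (448.0 / amax) / (2 ** margin)   # ~49.78
scale_inv = 1.0 / scale                  # ~0.0201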
@@ -3,9 +3,11 @@
# See LICENSE for license information.
"""Methods needed for distributed training (DP/TP)."""
-import warnings
+from __future__ import annotations
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
-from typing import Any, Dict, Union, Optional, Callable, Tuple, List
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import warnings
import torch
from torch.cuda import _lazy_call, _lazy_init
@@ -829,23 +831,48 @@ def reduce_scatter_along_first_dim(
def gather_along_first_dim(
-input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor]:
-"""Gather tensors and concatinate along the first dimension."""
+input_: torch.Tensor,
+process_group: dist_group_type,
+async_op: bool = False,
+) -> tuple[torch.Tensor, Any]:
+"""All-gather tensors and concatenate along first dimension."""
-world_size = get_distributed_world_size(tp_group)
-# Bypass the function if we are using only 1 GPU.
+# Return immediately if no communication is required
+world_size = get_distributed_world_size(process_group)
if world_size == 1:
return input_, None
-dim_size = list(input_.size())
-dim_size[0] = dim_size[0] * world_size
+# Allocate output tensor
+output_shape = list(input_.size())
output_shape[0] *= world_size
if isinstance(input_, Float8Tensor):
output = Float8Tensor.make_like(
input_,
data=torch.empty(
output_shape,
dtype=torch.uint8,
device=input_.device,
),
)
src = input_._data.contiguous()
dst = output._data
else:
output = torch.empty(
output_shape,
dtype=input_.dtype,
device=input_.device,
memory_format=torch.contiguous_format,
)
src = input_.contiguous()
dst = output
-output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+# Launch all-gather
handle = torch.distributed.all_gather_into_tensor(
-output, input_.contiguous(), group=tp_group, async_op=async_op
+dst,
+src,
+group=process_group,
+async_op=async_op,
)
return output, handle
...
@@ -563,6 +563,23 @@ class Float8Tensor(torch.Tensor):
return _IdentityFunc.apply(self)
return super().expand_as(other)
def contiguous(
self,
*,
memory_format: torch.memory_format = torch.contiguous_format,
) -> Float8Tensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if self._data.is_contiguous(memory_format=memory_format):
return self
return _IdentityFunc.apply(
self,
{"data": self._data.detach().contiguous(memory_format=memory_format)},
)
def transpose_2d(
self,
*,
@@ -885,6 +902,22 @@ class Float8Tensor(torch.Tensor):
fp8_attrs=args[0]._fp8_attrs,
)
# View op
if func == aten.view.default:
tensor = args[0]
data = tensor._data
data_view = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(
tensor,
data=data_view,
fp8_attrs=tensor._fp8_attrs,
)
def maybe_unwrap(t):
if isinstance(t, Float8Tensor):
return t.from_float8()
...
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operations.
This operation-based API is experimental and subject to change.
"""
from transformer_engine.pytorch.ops.basic import (
AllGather,
AllReduce,
BasicLinear,
Bias,
Identity,
ReduceScatter,
Reshape,
)
from transformer_engine.pytorch.ops.linear import Linear
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.sequential import Sequential
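A minimal usage sketch of this experimental API (not part of the commit; assumes a CUDA device and that `Sequential` composes ops like `torch.nn.Sequential`, as the unit tests suggest):

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Compose basic ops; the operation fuser may replace adjacent ops with fused kernels.
model = te_ops.Sequential(
    te_ops.BasicLinear(128, 128),
    te_ops.Bias(128),
)
x = torch.randn(32, 128, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True):
    y = model(x)
y.sum().backward()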
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Helper functions used in fusible operations."""
from __future__ import annotations
from typing import Any, Iterable, Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
"""Canonicalize PyTorch device
If `None`, then returns the default CUDA device.
"""
if device is None:
# Use default CUDA device
device = torch.get_default_device()
if device.type != "cuda":
device = torch.device("cuda", torch.cuda.current_device())
elif not isinstance(device, torch.device):
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device("cuda", torch.cuda.current_device())
return device
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
"""Canonicalize PyTorch datatype
If `None`, then returns the default PyTorch datatype.
"""
if dtype is None:
# Use default dtype
dtype = torch.get_default_dtype()
return dtype
def devices_match(device1: torch.device, device2: torch.device) -> bool:
"""Whether two devices are the same"""
device1 = torch.device(device1)
device2 = torch.device(device2)
if device1.type != device2.type:
return False
if device1.type == "cuda":
index1 = device1.index
index2 = device2.index
if index1 is None:
index1 = torch.cuda.current_device()
if index2 is None:
index2 = torch.cuda.current_device()
return index1 == index2
return device1 == device2
def is_float8_tensor(tensor: Any) -> bool:
"""Check if object is a `Float8Tensor`"""
return isinstance(tensor, Float8Tensor)
def convert_tensor(
tensor: torch.Tensor | Float8Tensor,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor | Float8Tensor:
"""Convert tensor attributes, keeping same data if possible"""
# Default kwargs
if device is None:
device = tensor.device
device = canonicalize_device(device)
if dtype is None:
dtype = tensor.dtype
dtype = canonicalize_dtype(dtype)
# Make sure output is detached from autograd graph
tensor = tensor.detach()
# Return immediately if tensor already has desired attributes
if devices_match(device, tensor.device) and dtype == tensor.dtype:
if memory_format == torch.preserve_format or tensor.is_contiguous(
memory_format=memory_format
):
return tensor
# Convert FP8 tensor
if is_float8_tensor(tensor):
data = tensor._data.to(device=device, memory_format=memory_format)
return Float8Tensor.make_like(
tensor,
data=data,
fp8_attrs=tensor._fp8_attrs,
dtype=dtype,
)
# Convert standard PyTorch tensor
return tensor.to(device=device, dtype=dtype, memory_format=memory_format)
def reshape(
tensor: torch.Tensor | Float8Tensor,
shape: Iterable[int],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor | Float8Tensor:
"""Reshape tensor, keeping same data if possible
If the input is a Float8Tensor, this function attempts to preserve
the cached transpose if available and valid. If a cached transpose
is present, it is interpreted as the transpose of a 2D matrix
where the width matches the innermost tensor dimension.
"""
# Make sure tensor is in expected format
tensor = convert_tensor(
tensor,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
# Return immediately if tensor already has desired shape
shape = list(shape)
if len(shape) == tensor.dim():
if sum(1 for d in shape if d == -1) > 1:
raise ValueError(
"Attempted to reshape tensor with "
f"shape={tuple(tensor.size())} into shape={tuple(shape)}"
)
if all(d1 == d2 for d1, d2 in zip(shape, tensor.size()) if d1 != -1):
return tensor
# Reshape FP8 tensor
# Note: Preserve cached transpose if possible
if is_float8_tensor(tensor):
out = Float8Tensor.make_like(
tensor,
data=tensor._data.view(shape),
fp8_attrs=tensor._fp8_attrs,
)
return out
# Reshape standard PyTorch tensor
return tensor.view(shape)
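A short illustration of these helpers (not part of the commit; assumes a CUDA build of PyTorch and that the module is importable as `transformer_engine.pytorch.ops._common`):

import torch
from transformer_engine.pytorch.ops._common import (
    canonicalize_device,
    canonicalize_dtype,
    reshape,
)

print(canonicalize_device(None))   # default CUDA device, e.g. device(type='cuda', index=0)
print(canonicalize_dtype(None))    # default dtype, torch.float32 unless overridden
x = torch.randn(4, 6, device="cuda")
y = reshape(x, (-1, 3))            # contiguous view with shape (8, 3)
print(y.shape)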
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Single tensor operations supported by the operation fuser."""
from .all_gather import AllGather
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .identity import Identity
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for all-gather."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import convert_tensor, is_float8_tensor
class AllGather(BasicOperation):
"""All-gather tensor along outer dimension
Equivalent to gathering tensors from all processes and
concatenating along the first dimension.
Parameters
----------
process_group: torch.distributed.ProcessGroup, default = world group
Process group for communication
"""
def __init__(
self,
process_group: Optional[torch.distributed.ProcessGroup] = None,
) -> None:
super().__init__()
self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
self.process_group_size: int = torch.distributed.get_world_size(process_group)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Trivial case
if self.process_group_size == 1:
return input_
# Tensor dimensions
input_dims = input_.size()
if not input_dims:
raise RuntimeError(
"Attempted to all-gather a tensor "
f"with shape={list(input_dims)} "
f"over {self.process_group_size} processes"
)
output_dims = list(input_dims)
output_dims[0] *= self.process_group_size
# Perform all-gather
x = convert_tensor(input_, memory_format=torch.contiguous_format)
y = None
if is_float8_tensor(x):
y = Float8Tensor.make_like(
x,
data=torch.empty(
output_dims,
dtype=torch.uint8,
device=x.device,
),
)
torch.distributed.all_gather_into_tensor(
y._data,
x._data,
group=self.process_group,
)
else:
y = torch.empty(output_dims, dtype=x.dtype, device=x.device)
torch.distributed.all_gather_into_tensor(
y,
x,
group=self.process_group,
)
return y
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Trivial case
if self.process_group_size == 1:
return grad_output, ()
# Tensor dimensions
output_dims = grad_output.size()
if not output_dims or output_dims[0] % self.process_group_size != 0:
raise RuntimeError(
"Attempted to reduce-scatter a tensor "
f"with shape={list(output_dims)} "
f"over {self.process_group_size} processes"
)
input_dims = list(output_dims)
input_dims[0] //= self.process_group_size
# Check output gradient tensor
dy = grad_output
if is_float8_tensor(dy):
dy = dy.from_float8()
dy = dy.contiguous()
# Perform reduce-scatter
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
torch.distributed.reduce_scatter_tensor(
dx,
dy,
group=self.process_group,
)
return dx, ()
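A hedged sketch of the op in isolation (not part of the commit; assumes `torch.distributed` is already initialized, e.g. via `torchrun`, and that basic ops are callable standalone as in the unit tests):

import torch
import torch.distributed as dist
import transformer_engine.pytorch.ops as te_ops

world_size = dist.get_world_size()
op = te_ops.AllGather()                      # defaults to the world process group
x = torch.randn(8, 64, device="cuda", requires_grad=True)
y = op(x)                                    # shape (8 * world_size, 64)
y.sum().backward()                           # backward is a reduce-scatter
print(y.shape, x.grad.shape)                 # (8 * world_size, 64), (8, 64)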
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for all-reduce."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import is_float8_tensor
class AllReduce(BasicOperation):
"""All-reduce tensor
Equivalent to summing tensors from all processes. It is assumed
that the output is used in operations that are redundantly
computed on all processes, and hence that gradients are identical
between processes.
Parameters
----------
process_group: torch.distributed.ProcessGroup, default = world group
Process group for communication
"""
def __init__(
self,
process_group: Optional[torch.distributed.ProcessGroup] = None,
reduce_in_backward: bool = True,
) -> None:
super().__init__()
self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
self._reduce_in_backward: bool = reduce_in_backward
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Trivial case
if torch.distributed.get_world_size(self.process_group) == 1:
return input_
# Perform all-reduce
x = input_
if is_float8_tensor(x):
x = x.from_float8()
x = x.contiguous()
torch.distributed.all_reduce(x, group=self.process_group)
return x
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
return grad_output, ()
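A corresponding sketch for all-reduce (not part of the commit; same assumptions as for the AllGather example above):

import torch
import torch.distributed as dist
import transformer_engine.pytorch.ops as te_ops

op = te_ops.AllReduce()
x = torch.full((4, 4), float(dist.get_rank() + 1), device="cuda")
y = op(x)   # every entry is now 1 + 2 + ... + world_size; backward is an identity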
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for linear layer without bias."""
from __future__ import annotations
from collections.abc import Callable, Iterable
import contextlib
import math
from typing import Any, Optional
import torch
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, gemm
from transformer_engine.pytorch.distributed import (
CudaRNGStatesTracker,
gather_along_first_dim,
reduce_scatter_along_first_dim,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
get_fp8_te_dtype,
)
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import (
canonicalize_device,
canonicalize_dtype,
convert_tensor,
is_float8_tensor,
reshape,
)
from ...utils import clear_tensor_data
def _wait_async(handle: Optional[Any]) -> None:
"""Wait for asynchronous communication to finish, if needed"""
if handle is not None:
handle.wait()
class BasicLinear(BasicOperation):
"""Apply linear transformation: :math:`y = x A^T`
This is a drop-in replacement for `torch.nn.Linear` with
`bias=False`.
Parameters
----------
in_features: int
Inner dimension of input tensor
out_features: int
Inner dimension of output tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors along
outer dimension (sequence or batch dim) when not distributing
along inner dimension (embedding dim)
rng_state_tracker_function: callable
Function that returns `CudaRNGStatesTracker`, which is used
for model-parallel weight initialization
accumulate_into_main_grad: bool, default = `False`
Whether to directly accumulate weight gradients into the
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful.
"""
def __init__(
self,
in_features: int,
out_features: int,
*,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sequence_parallel: bool = False,
rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
accumulate_into_main_grad: bool = False,
) -> None:
super().__init__()
# Weight tensor dimensions
self.in_features: int = in_features
self.out_features: int = out_features
# Weight tensor device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device
# Weight tensor datatype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
self.dtype: torch.dtype = canonicalize_dtype(dtype)
# Tensor parallel configuration
self.tensor_parallel_mode: Optional[str]
self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup]
self.tensor_parallel_size: int
self.sequence_parallel: bool
self.local_in_features: int
self.local_out_features: int
(
self.tensor_parallel_mode,
self.tensor_parallel_group,
self.tensor_parallel_size,
self.sequence_parallel,
self.local_in_features,
self.local_out_features,
) = self._canonicalize_tensor_parallelism(
mode=tensor_parallel_mode,
process_group=tensor_parallel_group,
sequence_parallel=sequence_parallel,
in_features=in_features,
out_features=out_features,
)
# Whether weight tensor is natively in FP8
self._with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
if self._with_fp8_parameters:
self._fp8_metas = self._make_fp8_metas()
# Initialize parameters if needed
weight = torch.empty(
self.local_out_features,
self.local_in_features,
device="meta",
dtype=dtype,
)
weight = torch.nn.Parameter(weight)
self.weight: torch.nn.Parameter
self.register_parameter("weight", weight)
self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]]
self._rng_state_tracker_function = rng_state_tracker_function
if not defer_param_init:
self.reset_parameters()
# Whether to accumulate weight gradient into main_grad
self._accumulate_into_main_grad = accumulate_into_main_grad
@classmethod
def _canonicalize_tensor_parallelism(
cls,
*,
mode: Optional[str],
process_group: Optional[torch.distributed.ProcessGroup],
sequence_parallel: bool,
in_features: int,
out_features: int,
) -> tuple[
Optional[str],
Optional[torch.distributed.ProcessGroup],
int,
bool,
int,
int,
]:
"""Check configuration for tensor parallelism
Parameters
----------
mode: {`None`, "column", "row"}
Mode for tensor parallelism
process_group: torch.distributed.ProcessGroup
Process group for tensor parallelism
sequence_parallel: bool
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
in_features: int
Inner dimension of global input tensor
out_features: int
Inner dimension of global output tensor
Returns
-------
mode: {`None`, "column", "row"}
Mode for tensor parallelism
process_group: torch.distributed.ProcessGroup
Process group for tensor parallelism
group_size: int
Size of tensor-parallel process group
sequence_parallel: bool
Whether to apply sequence parallelism
local_in_features: int
Inner dimension of local input tensor
local_out_features: int
Inner dimension of local output tensor
"""
# Tensor-parallel group size
if mode is None:
group_size = 1
else:
group_size = torch.distributed.get_world_size(process_group)
# Disable tensor parallelism if not needed
if group_size == 1:
mode = None
process_group = None
sequence_parallel = False
# Determine local tensor dims
local_in_features = in_features
local_out_features = out_features
if mode is None:
pass
elif mode == "column":
# Distribute output tensor
if out_features % group_size != 0:
raise ValueError(
"Invalid configuration for tensor parallelism "
f"({mode=}, {out_features=}, {group_size=})"
)
local_out_features //= group_size
elif mode == "row":
# Distribute input tensor
if in_features % group_size != 0:
raise ValueError(
"Invalid configuration for tensor parallelism "
f"({mode=}, {in_features=}, {group_size=})"
)
local_in_features //= group_size
else:
raise ValueError(
"Supported modes for tensor parallelism are "
f'`None`, "row", and "column" (got {mode=})'
)
return (
mode,
process_group,
group_size,
sequence_parallel,
local_in_features,
local_out_features,
)
def num_fp8_scales(self, mode: str) -> int:
if mode in ("input", "param", "grad_output"):
return 1
return 0
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""
# Make sure parameter is initialized
weight = self.weight
if weight.device.type != "cuda" or is_float8_tensor(weight):
weight = torch.empty_like(weight, device=self.device)
weight = weight.to(device=self.device, dtype=self.dtype)
# Initialize values
init_context = contextlib.nullcontext
if self._rng_state_tracker_function is not None:
init_context = self._rng_state_tracker_function().fork
with init_context():
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
# Cast to FP8 if needed
if self._with_fp8_parameters:
weight = Float8Tensor.to_float8(
weight,
fp8_meta=self.get_fp8_meta("param"),
fp8_meta_index=0,
)
# Save updated parameter
if not isinstance(weight, torch.nn.Parameter):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_forward(self) -> None:
super().pre_forward()
if self.weight.device.type == "meta":
self.reset_parameters()
@staticmethod
def _functional_forward(
input: torch.Tensor, # pylint: disable=redefined-builtin
weight: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sequence_parallel: bool = False,
with_fp8_compute: bool = False,
input_fp8_meta: Optional[dict[str, Any]] = None,
weight_fp8_meta: Optional[dict[str, Any]] = None,
output_fp8_meta: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Functional API for forward pass
Parameters
----------
input: torch.Tensor
Input tensor
weight: torch.Tensor
Weight tensor
bias: torch.Tensor, optional
Bias tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
Whether to perform compute in FP8
input_fp8_meta: dict, optional
FP8 metadata for casting input tensor to FP8. Required for
FP8 compute if input is not already in FP8.
weight_fp8_meta: dict, optional
FP8 metadata for casting weight tensor to FP8. Required for
FP8 compute if weight is not already in FP8.
output_fp8_meta: dict, optional
FP8 metadata for casting output tensor to FP8
Returns
-------
torch.Tensor
Output tensor
torch.Tensor
Input tensor used in GEMM, possibly cast and reshaped from
provided input tensor
torch.Tensor
Weight tensor used in GEMM, possibly cast and reshaped from
provided weight tensor
"""
# Check device
if device is None:
device = weight.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype
if dtype is None:
dtype = weight.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Check tensor dims
input_dims = tuple(input.size())
weight_dims = tuple(weight.size())
if len(weight_dims) != 2:
raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
# Check if FP8 is enabled
if with_fp8_compute:
if input_fp8_meta is None and not is_float8_tensor(input):
raise ValueError("No FP8 metadata was provided for casting input to FP8")
if weight_fp8_meta is None and not is_float8_tensor(weight):
raise ValueError("No FP8 metadata was provided for casting weight to FP8")
else:
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
with_fp8_output = (
with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None
)
# Check input tensor
x_local = reshape(
input,
(-1, input_dims[-1]),
device=device,
dtype=dtype,
)
if with_fp8_compute and not is_float8_tensor(x_local):
fp8_dtype = get_fp8_te_dtype(
input_fp8_meta["recipe"],
fprop_tensor=True,
)
x_fp8 = Float8Tensor(
data=torch.empty_like(x_local, dtype=torch.uint8),
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
)
with_cast_transpose = weight.requires_grad
if tensor_parallel_mode == "column" and sequence_parallel:
with_cast_transpose = False
if with_cast_transpose:
x_fp8.cast_transpose_(x_local)
else:
x_fp8.copy_(x_local)
x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8()
x = x_local
x_async = None
if tensor_parallel_mode == "column" and sequence_parallel:
x, x_async = gather_along_first_dim(
x_local,
tensor_parallel_group,
async_op=True,
)
# Check weight tensor
w = convert_tensor(
weight,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
if with_fp8_compute and not is_float8_tensor(w):
fp8_dtype = get_fp8_te_dtype(
weight_fp8_meta["recipe"],
fprop_tensor=True,
)
w = Float8Tensor.to_float8(
w,
fp8_meta=weight_fp8_meta,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)
elif not with_fp8_compute and is_float8_tensor(w):
w = w.from_float8()
# Check bias tensor
b = None
if bias is not None:
b = convert_tensor(
bias,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
# Construct output tensor
y = None
if with_fp8_output:
fp8_dtype = get_fp8_te_dtype(
output_fp8_meta["recipe"],
fprop_tensor=True,
)
data = torch.empty(
(x.size(0), weight_dims[0]),
dtype=torch.uint8,
device=device,
)
y = Float8Tensor(
data=data,
fp8_meta=output_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
y = torch.empty(
(x.size(0), weight_dims[0]),
dtype=dtype,
device=device,
)
# Perform GEMM
_wait_async(x_async)
x_async = None
if with_fp8_compute:
kwargs = dict(
out=y,
bias=b,
use_bias=(b is not None),
)
if with_fp8_output:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=y._fp8_meta_forward,
)
kwargs.update(
dict(
out=y._data,
out_index=y._fp8_meta_index,
fp8_meta_tensor=y._fp8_meta[fp8_meta_key],
D_dtype=y._fp8_dtype,
)
)
fp8_gemm(
w._data,
w._scale_inv,
0,
w._fp8_dtype,
x._data,
x._scale_inv,
0,
x._fp8_dtype,
y.dtype,
get_workspace(),
**kwargs,
)
else:
gemm(
w,
x,
y.dtype,
get_workspace(),
out=y,
bias=b,
use_bias=(b is not None),
)
# Reduce tensor-parallel output if needed
if tensor_parallel_mode == "row":
if sequence_parallel:
y, _ = reduce_scatter_along_first_dim(y, tensor_parallel_group)
else:
torch.distributed.all_reduce(y, group=tensor_parallel_group)
# Reshape output tensor
output_dims = list(input_dims)
output_dims[0] = -1
output_dims[-1] = weight_dims[0]
output = reshape(y, output_dims)
return output, x_local, w
@staticmethod
def _functional_backward(
grad_output: torch.Tensor,
input: Optional[torch.Tensor], # pylint: disable=redefined-builtin
weight: Optional[torch.Tensor],
input_dims: Iterable[int],
weight_dims: Iterable[int],
*,
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sequence_parallel: bool = False,
with_fp8_compute: bool = False,
input_fp8_meta: Optional[dict[str, Any]] = None,
weight_fp8_meta: Optional[dict[str, Any]] = None,
grad_output_fp8_meta: Optional[dict[str, Any]] = None,
grad_input_fp8_meta: Optional[dict[str, Any]] = None,
accumulate_into_grad_weight: bool = False,
grad_weight: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Functional API for backward pass
Parameters
----------
grad_output: torch.Tensor
Loss gradient w.r.t. output tensor
input: torch.Tensor, optional
Input tensor. Required to compute loss gradient w.r.t.
weight.
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
input.
input_dims: iterable of int
Input tensor dimensions
weight_dims: iterable of int
Weight tensor dimensions
input_requires_grad: bool
Whether to compute loss gradient w.r.t. input tensor
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
Whether to perform compute in FP8
input_fp8_meta: dict, optional
FP8 metadata for casting input tensor to FP8. Required for
FP8 compute if input is not already in FP8.
weight_fp8_meta: dict, optional
FP8 metadata for casting weight tensor to FP8. Required for
FP8 compute if weight is not already in FP8.
grad_output_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. output
tensor to FP8. Required if output grad is not already in
FP8.
grad_input_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. input
tensor to FP8
accumulate_into_grad_weight: bool, default = `False`
Accumulate into weight grad instead of overwriting
grad_weight: torch.Tensor, optional
Loss gradient w.r.t. weight tensor
Returns
-------
torch.Tensor
Loss gradient w.r.t. input tensor
torch.Tensor
Loss gradient w.r.t. weight tensor
"""
# Check device
if device is None:
device = weight.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype
if dtype is None:
dtype = weight.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Check tensor dims
output_dims = tuple(grad_output.size())
input_dims = tuple(input_dims)
weight_dims = tuple(weight_dims)
if len(weight_dims) != 2:
raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
if weight_dims[0] != output_dims[-1]:
raise ValueError(
f"Grad output tensor (shape={output_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
# Check if FP8 is enabled
if with_fp8_compute:
if grad_output_fp8_meta is None and not is_float8_tensor(grad_output):
raise ValueError("No FP8 metadata was provided for casting output gradient to FP8")
else:
input_fp8_meta = None
weight_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
with_fp8_grad_input = (
with_fp8_compute
and input_requires_grad
and tensor_parallel_mode != "column"
and grad_input_fp8_meta is not None
)
# Check grad output tensor
dy_async = None
dy = reshape(
grad_output,
(-1, output_dims[-1]),
device=device,
dtype=dtype,
)
if with_fp8_compute and not is_float8_tensor(dy):
fp8_dtype = get_fp8_te_dtype(
grad_output_fp8_meta["recipe"],
fprop_tensor=False,
)
dy_fp8 = Float8Tensor(
data=torch.empty_like(dy, dtype=torch.uint8),
fp8_meta=grad_output_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
)
with_cast_transpose = weight_requires_grad
if tensor_parallel_mode == "row" and sequence_parallel:
with_cast_transpose = False
if with_cast_transpose:
dy_fp8.cast_transpose_(dy)
else:
dy_fp8.copy_(dy)
dy = dy_fp8
elif not with_fp8_compute and is_float8_tensor(dy):
dy = dy.from_float8()
if tensor_parallel_mode == "row" and sequence_parallel:
dy, dy_async = gather_along_first_dim(
dy,
tensor_parallel_group,
async_op=True,
)
# Check input tensor
x = None
x_async = None
if weight_requires_grad:
if input is None:
raise ValueError("Input tensor is required to compute weight grad")
x_local = reshape(
input,
(-1, input_dims[-1]),
device=device,
dtype=dtype,
)
if with_fp8_compute and not is_float8_tensor(x_local):
fp8_dtype = get_fp8_te_dtype(
input_fp8_meta["recipe"],
fprop_tensor=True,
)
x_fp8 = Float8Tensor(
data=torch.empty_like(x_local, dtype=torch.uint8),
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
)
x_fp8.cast_transpose_(x_local)
x_local = x_fp8
elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8()
x = x_local
if tensor_parallel_mode == "column" and sequence_parallel:
x, x_async = gather_along_first_dim(
x_local,
tensor_parallel_group,
async_op=True,
)
# Compute grad input
dx = None
dx_async = None
if input_requires_grad:
# Check weight tensor
if weight is None:
raise ValueError("Weight tensor is required to compute input grad")
w = convert_tensor(
weight,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
if with_fp8_compute and not is_float8_tensor(w):
fp8_dtype = get_fp8_te_dtype(
weight_fp8_meta["recipe"],
fprop_tensor=True,
)
w_fp8 = Float8Tensor(
data=torch.empty_like(w, dtype=torch.uint8),
fp8_meta=weight_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device),
dtype=dtype,
)
w_fp8.cast_transpose_(w)
w = w_fp8
elif not with_fp8_compute and is_float8_tensor(w):
w = w.from_float8()
# Construct grad input tensor
if with_fp8_grad_input:
fp8_dtype = get_fp8_te_dtype(
grad_input_fp8_meta["recipe"],
fprop_tensor=False,
)
data = torch.empty(
(dy.size(0), weight_dims[1]),
dtype=torch.uint8,
device=device,
)
dx = Float8Tensor(
data=data,
fp8_meta=grad_input_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
dx = torch.empty(
(dy.size(0), weight_dims[1]),
dtype=dtype,
device=device,
)
# Perform dgrad GEMM
_wait_async(dy_async)
dy_async = None
if with_fp8_compute:
kwargs = dict(out=dx)
if with_fp8_grad_input:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dx._fp8_meta_forward,
)
kwargs.update(
dict(
out=dx._data,
out_index=dx._fp8_meta_index,
fp8_meta_tensor=dx._fp8_meta[fp8_meta_key],
D_dtype=dx._fp8_dtype,
)
)
fp8_gemm(
w.transpose_2d(),
w._scale_inv,
0,
w._fp8_dtype,
dy._data,
dy._scale_inv,
0,
dy._fp8_dtype,
dx.dtype,
get_workspace(),
**kwargs,
)
else:
gemm(
w,
dy,
dx.dtype,
get_workspace(),
layout="NN",
out=dx,
)
# Reduce tensor-parallel grad input if needed
if tensor_parallel_mode == "column":
if sequence_parallel:
dx, dx_async = reduce_scatter_along_first_dim(
dx,
tensor_parallel_group,
async_op=True,
)
else:
dx_async = torch.distributed.all_reduce(
dx,
group=tensor_parallel_group,
async_op=True,
)
# Perform wgrad GEMM
if not weight_requires_grad:
grad_weight = None
else:
if grad_weight is None:
if accumulate_into_grad_weight:
raise ValueError(
"Attempted to accumulate into grad weight buffer"
"without providing grad weight"
)
grad_weight = torch.empty(
weight_dims,
dtype=dtype,
device=device,
memory_format=torch.contiguous_format,
)
_wait_async(dy_async)
_wait_async(x_async)
dy_async = None
x_async = None
if with_fp8_compute:
fp8_gemm(
x.transpose_2d(),
x._scale_inv,
0,
x._fp8_dtype,
dy.transpose_2d(),
dy._scale_inv,
0,
dy._fp8_dtype,
grad_weight.dtype,
get_workspace(),
accumulate=accumulate_into_grad_weight,
out=grad_weight,
)
else:
gemm(
x,
dy,
x.dtype,
get_workspace(),
accumulate=accumulate_into_grad_weight,
layout="NT",
out=grad_weight,
)
# Clean up and return grads
_wait_async(dy_async)
_wait_async(x_async)
_wait_async(dx_async)
grad_input = None
if dx is not None:
grad_input = reshape(dx, input_dims)
return grad_input, grad_weight
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# FP8 metadata
with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
if with_fp8_compute:
input_fp8_meta = self.get_fp8_meta("input")
weight_fp8_meta = self.get_fp8_meta("param")
if next_op is not None and next_op.num_fp8_scales("input") > 0:
output_fp8_meta = next_op.get_fp8_meta("input")
grad_output_fp8_meta = self.get_fp8_meta("grad_output")
if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
input=input_,
weight=self.weight,
device=self.device,
dtype=self.dtype,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
sequence_parallel=self.sequence_parallel,
with_fp8_compute=with_fp8_compute,
input_fp8_meta=input_fp8_meta,
weight_fp8_meta=weight_fp8_meta,
output_fp8_meta=output_fp8_meta,
)
# Save state for backward pass
ctx.save_for_backward(x_local)
ctx.with_fp8_compute = with_fp8_compute
ctx.weight_fp8_meta = weight_fp8_meta
ctx.grad_output_fp8_meta = grad_output_fp8_meta
ctx.grad_input_fp8_meta = grad_input_fp8_meta
ctx.input_dims = input_.size()
ctx.input_requires_grad = input_.requires_grad
ctx.weight_requires_grad = self.weight.requires_grad
ctx.has_prev_op = prev_op is not None
return output
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
# Saved tensors from forward pass
(x_local,) = ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = self._accumulate_into_main_grad
grad_weight = None
if ctx.weight_requires_grad and accumulate_into_main_grad:
if not hasattr(self.weight, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = self.weight.main_grad.detach()
else:
accumulate_into_main_grad = False
# Linear backward pass
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=self.weight,
input_dims=ctx.input_dims,
weight_dims=self.weight.size(),
input_requires_grad=ctx.input_requires_grad,
weight_requires_grad=ctx.weight_requires_grad,
device=self.device,
dtype=self.dtype,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
sequence_parallel=self.sequence_parallel,
with_fp8_compute=ctx.with_fp8_compute,
weight_fp8_meta=ctx.weight_fp8_meta,
grad_output_fp8_meta=ctx.grad_output_fp8_meta,
grad_input_fp8_meta=ctx.grad_input_fp8_meta,
accumulate_into_grad_weight=accumulate_into_main_grad,
grad_weight=grad_weight,
)
# Clear input tensor if possible
if ctx.has_prev_op:
clear_tensor_data(x_local)
if accumulate_into_main_grad:
grad_weight = None
return grad_input, [grad_weight]
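A standalone usage sketch, mirroring the new FP8-recipe unit test (not part of the commit; assumes a CUDA device):

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

op = te_ops.BasicLinear(256, 256, dtype=torch.bfloat16)
x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)

y = op(x)                             # high-precision compute
with te.fp8_autocast(enabled=True):   # FP8 compute with the default recipe
    y = op(x)
y.sum().backward()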
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for bias."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import (
canonicalize_device,
canonicalize_dtype,
)
class Bias(BasicOperation):
"""Apply additive bias
This is equivalent to the additive bias in `torch.nn.Linear`.
Parameters
----------
size: int
Inner dimension of input tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel: bool, default = `False`
Whether to distribute input tensor and bias tensors along
inner dimension
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
"""
def __init__(
self,
size: int,
*,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
tensor_parallel: bool = False,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> None:
super().__init__()
# Bias size
self._size = size
# Bias tensor device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
self.device: torch.device = device
# Bias tensor datatype
self.dtype: torch.dtype = canonicalize_dtype(dtype)
# Tensor parallel configuration
tensor_parallel_size = 1
local_size = size
if tensor_parallel:
tensor_parallel_size = torch.distributed.get_world_size(tensor_parallel_group)
tensor_parallel = tensor_parallel_size > 1
if size % tensor_parallel_size != 0:
raise ValueError(
"Invalid configuration for tensor parallelism "
f"({size=}, {tensor_parallel_size=})"
)
local_size //= tensor_parallel_size
else:
tensor_parallel_group = None
self.tensor_parallel: bool = tensor_parallel
self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = tensor_parallel_group
self.tensor_parallel_size: int = tensor_parallel_size
self.local_size: int = local_size
# Initialize parameters if needed
bias = torch.empty(
local_size,
device="meta",
dtype=dtype,
)
bias = torch.nn.Parameter(bias)
self.bias: torch.nn.Parameter
self.register_parameter("bias", bias)
if not defer_param_init:
self.reset_parameters()
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""
# Make sure parameter is initialized
bias = self.bias
if bias.device.type != "cuda":
bias = torch.empty_like(bias, device=self.device)
bias = bias.to(device=self.device, dtype=self.dtype)
# Initialize values
bias.zero_()
# Save updated parameter
if not isinstance(bias, torch.nn.Parameter):
bias = torch.nn.Parameter(bias)
self.bias = bias
def pre_forward(self) -> None:
super().pre_forward()
if self.bias.device.type == "meta":
self.reset_parameters()
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
x = input_
b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size])
return x + b
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
dy = grad_output
if dy.dim() > 1:
db = dy.sum(tuple(range(dy.dim() - 1)))
else:
db = dy
return dy, (db,)
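A brief sketch of the op (not part of the commit; assumes a CUDA device):

import torch
import transformer_engine.pytorch.ops as te_ops

op = te_ops.Bias(64)                   # bias is zero-initialized
x = torch.randn(8, 16, 64, device="cuda")
y = op(x)                              # bias broadcast over the leading dims
print(torch.equal(y, x))               # True until the bias is trained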
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for identity."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
class Identity(BasicOperation):
"""Return input tensor"""
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
return input_
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
return grad_output, ()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for reduce-scatter."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import convert_tensor, is_float8_tensor
class ReduceScatter(BasicOperation):
"""Reduce-scatter tensor along outer dimension
Equivalent to summing tensors from all processes and splitting
along the first dimension.
Parameters
----------
process_group: torch.distributed.ProcessGroup, default = world group
Process group for communication
"""
def __init__(
self,
process_group: Optional[torch.distributed.ProcessGroup] = None,
) -> None:
super().__init__()
self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
self.process_group_size: int = torch.distributed.get_world_size(process_group)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Trivial case
if self.process_group_size == 1:
return input_
# Tensor dimensions
input_dims = input_.size()
if not input_dims or input_dims[0] % self.process_group_size != 0:
raise RuntimeError(
"Attempted to reduce-scatter a tensor "
f"with shape={list(input_dims)} "
f"over {self.process_group_size} processes"
)
output_dims = list(input_dims)
output_dims[0] //= self.process_group_size
# Check input tensor
x = input_
if is_float8_tensor(x):
x = x.from_float8()
x = x.contiguous()
# Perform reduce-scatter
y = torch.empty(output_dims, dtype=x.dtype, device=x.device)
torch.distributed.reduce_scatter_tensor(y, x, group=self.process_group)
return y
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Trivial case
if self.process_group_size == 1:
return grad_output, ()
# Tensor dimensions
output_dims = grad_output.size()
if not output_dims:
raise RuntimeError(
"Attempted to all-gather a tensor "
f"with shape={list(output_dims)} "
f"over {self.process_group_size} processes"
)
input_dims = list(output_dims)
input_dims[0] *= self.process_group_size
# Perform all-gather
dy = convert_tensor(grad_output, memory_format=torch.contiguous_format)
dx = None
if is_float8_tensor(dy):
dx = Float8Tensor.make_like(
dy,
data=torch.empty(
input_dims,
dtype=torch.uint8,
device=dy.device,
),
)
torch.distributed.all_gather_into_tensor(
dx._data,
dy._data,
group=self.process_group,
)
else:
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
torch.distributed.all_gather_into_tensor(
dx,
dy,
group=self.process_group,
)
return dx, ()
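A hedged sketch (not part of the commit; assumes an initialized process group, e.g. via `torchrun`):

import torch
import torch.distributed as dist
import transformer_engine.pytorch.ops as te_ops

world_size = dist.get_world_size()
op = te_ops.ReduceScatter()
x = torch.randn(8 * world_size, 64, device="cuda")
y = op(x)   # shape (8, 64): summed over ranks, then split along the first dim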
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for reshape."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Optional
import torch
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import reshape
class Reshape(BasicOperation):
"""Reshape tensor
See `torch.reshape`.
Parameters
----------
shape: iterable of int
Output tensor dimensions. If one dimension is -1, it is
inferred based on input tensor dimensions.
"""
def __init__(self, shape: Iterable[int]) -> None:
super().__init__()
self._shape = tuple(shape)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
ctx.input_shape = input_.size()
return reshape(input_, self._shape)
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
return reshape(grad_output, ctx.input_shape), ()
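A brief sketch (not part of the commit; assumes a CUDA device):

import torch
import transformer_engine.pytorch.ops as te_ops

op = te_ops.Reshape((-1, 32))
x = torch.randn(4, 8, 32, device="cuda", requires_grad=True)
y = op(x)              # shape (32, 32)
y.sum().backward()     # the backward pass restores x.grad to shape (4, 8, 32)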
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Compound tensor operation supported by the operation fuser."""
from .linear_bias_activation import (
ForwardLinearBiasActivation,
fuse_forward_linear_bias_activation,
)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused operation for GEMM, bias, activation in the forward pass."""
from __future__ import annotations
from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
)
class ForwardLinearBiasActivation(FusedOperation):
"""Fused GEMM, bias, activation in the forward pass
Bias and activation are both optional. Row tensor parallelism is
not supported since that requires communication immediately after
the GEMM.
"""
def __init__(
self,
*,
linear: BasicLinear,
bias: Optional[Bias],
activation: None,
) -> None:
# Basic operations that comprise this fused operation
op_idxs = dict(
linear=0,
bias=None,
activation=None,
)
ops = [linear]
if bias is not None:
op_idxs["bias"] = len(ops)
ops.append(bias)
if activation is not None:
op_idxs["activation"] = len(ops)
ops.append(activation)
# Initialize base class
super().__init__(ops)
# Index of each basic operation
self._op_idxs: dict[str, Optional[int]] = op_idxs
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> torch.Tensor:
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
if self._op_idxs["bias"] is None:
bias_op = None
bias = None
else:
idx = self._op_idxs["bias"]
bias_op = self.basic_ops[idx]
bias = bias_op.bias
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
if self._op_idxs["activation"] is None:
activation_op = None # pylint: disable=unused-variable
else:
raise NotImplementedError("Activations are not yet supported")
# FP8 metadata
with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
if with_fp8_compute:
input_fp8_meta = linear_op.get_fp8_meta("input")
weight_fp8_meta = linear_op.get_fp8_meta("param")
next_op = basic_op_next_ops[-1]
if next_op is not None and next_op.num_fp8_scales("input") > 0:
output_fp8_meta = next_op.get_fp8_meta("input")
grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output")
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
device=linear_op.device,
dtype=linear_op.dtype,
tensor_parallel_mode=linear_op.tensor_parallel_mode,
tensor_parallel_group=linear_op.tensor_parallel_group,
sequence_parallel=linear_op.sequence_parallel,
with_fp8_compute=with_fp8_compute,
input_fp8_meta=input_fp8_meta,
weight_fp8_meta=weight_fp8_meta,
output_fp8_meta=output_fp8_meta,
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.with_fp8_compute = with_fp8_compute
linear_op_ctx.weight_fp8_meta = weight_fp8_meta
linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output
def fuse_forward_linear_bias_activation(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse GEMM, bias, activation in the forward pass
Parameters
----------
ops: list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops: list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op1, _ = window[0]
if not isinstance(op1, BasicLinear):
continue
if op1.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
if op1.dtype not in (torch.float16, torch.bfloat16):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
continue
# Check if second op is bias
op2, _ = ops[0]
if not isinstance(op2, Bias):
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasActivation(
linear=window[0][0],
bias=window[1][0],
activation=None,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
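A hedged illustration of the rewrite this function performs (not part of the commit; assumes a CUDA device, and the op-list format follows the docstring above):

import torch
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused_forward import (
    fuse_forward_linear_bias_activation,
)

linear = te_ops.BasicLinear(64, 64, dtype=torch.bfloat16)
bias = te_ops.Bias(64, dtype=torch.bfloat16)
ops = [(linear, [0]), (bias, [1])]
fused = fuse_forward_linear_bias_activation(ops)
print([type(op).__name__ for op, _ in fused])   # ['ForwardLinearBiasActivation']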
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Manager class for a pipeline of fusible operations."""
from __future__ import annotations
from typing import Any, Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.graph import is_graph_capturing
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusibleOperation,
OperationContext,
)
from transformer_engine.pytorch.ops.fused_forward import (
fuse_forward_linear_bias_activation,
)
class _OperationFuserAutogradFunction(torch.autograd.Function):
"""Autograd function for a pipeline of operations
Autograd must be done at the pipeline level since we may apply
different fusions in the forward and backward passes.
"""
# pylint: disable=unused-argument
@staticmethod
def forward(
func_ctx: torch.autograd.function.FunctionCtx,
input_: torch.Tensor,
forward_ops: list[tuple[FusibleOperation, list[int]]],
backward_ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: list[BasicOperation],
basic_op_kwargs: list[dict[str, Any]],
*params: torch.nn.Parameter,
) -> torch.Tensor:
"""Forward pass
Parameters
----------
func_ctx: torch.autograd.function.FunctionCtx
Context for PyTorch autograd function
input_: torch.Tensor
Input to first operation in pipeline
forward_ops: list of tuple
Forward pass operations and the indices of the
corresponding basic operations. The order should match
basic_ops.
backward_ops: list of tuple
Backward pass operations and the indices of the
corresponding basic operations. The order should be the
reverse of basic_ops.
basic_ops: list of BasicOperation
Basic operations
basic_op_kwargs: list of dict
Keyword arguments to BasicOperation
*params: torch.nn.Parameter
Parameters in operation pipeline
"""
# Operation autograd contexts
basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]
# Apply forward ops
x = input_
requires_grad = x.requires_grad
for op, basic_op_idxs in forward_ops:
# Forward op
prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
next_ops = [
basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
]
x = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
x,
prev_ops,
next_ops,
[basic_op_kwargs[idx] for idx in basic_op_idxs],
)
# Check if backward op is required
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
for idx in basic_op_idxs:
basic_op_ctxs[idx]._requires_grad = requires_grad
x.requires_grad_(requires_grad=requires_grad)
# Flatten list of saved tensors
to_save = []
for ctx in basic_op_ctxs:
range_start = len(to_save)
if ctx.to_save is not None:
to_save.extend(ctx.to_save)
range_end = len(to_save)
ctx.to_save = None
ctx._saved_tensors_range = (range_start, range_end)
func_ctx.save_for_backward(*to_save)
# Other context for backward pass
func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
return x
@staticmethod
@torch.autograd.function.once_differentiable
def backward(
func_ctx: Any,
grad_output: torch.Tensor,
) -> tuple[Optional[torch.Tensor], ...]:
"""Backward pass"""
# Operations and autograd state
backward_ops = func_ctx.backward_ops
basic_ops = func_ctx.basic_ops
basic_op_ctxs = func_ctx.basic_op_ctxs
# Unflatten list of saved tensors
saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs:
ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None
del saved_tensors
# Apply backward ops
dx = grad_output
grad_params = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in backward_ops:
# Stop if no more gradients are required
if all(not basic_op_ctxs[idx]._requires_grad for idx in basic_op_idxs):
dx = None
break
# Backward op
dx, fused_op_dparams = op.fuser_backward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
dx,
)
for idx, basic_op_dparams in zip(basic_op_idxs, fused_op_dparams):
grad_params[idx] = basic_op_dparams
basic_op_ctxs[idx].saved_tensors = None
# Flatten list of parameter gradients
grad_params_flat = []
for idx, dparams in enumerate(grad_params):
params = list(basic_ops[idx].parameters())
if dparams is None:
dparams = [None for _ in range(len(params))]
else:
dparams = list(dparams)
if len(dparams) != len(params):
raise RuntimeError(
f"Expected op {idx} to generate {len(params)} param grads, "
f"but got {len(dparams)}"
)
grad_params_flat.extend(dparams)
# Update FP8 scaling factors
if func_ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
dx, # input_
None, # forward_ops
None, # backward_ops
None, # basic_ops
None, # basic_op_kwargs
*grad_params_flat, # params
)
class OperationFuser:
"""Manages forward and backward passes for a pipeline of operations
Parameters
----------
ops: list of FusibleOperation
Pipeline of operations
fuse_ops: bool, default = `True`
Whether to attempt fusing operations
"""
def __init__(
self,
ops: list[FusibleOperation],
fuse_ops: bool = True,
) -> None:
# Get list of basic operations
basic_ops = []
for op in ops:
if op.is_fused_op:
basic_ops.extend(op.basic_ops)
else:
basic_ops.append(op)
self._num_basic_ops: int = len(basic_ops)
self._basic_ops: list[BasicOperation] = basic_ops
# Ops for forward and backward pass
self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]]
self._forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
self._backward_ops = list(reversed(self._forward_ops))
# Fuse ops if needed
if fuse_ops:
self.fuse_ops()
@classmethod
def _fuse_forward_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass"""
ops = fuse_forward_linear_bias_activation(ops)
return ops
@classmethod
def _fuse_backward_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass"""
return ops
def fuse_ops(self) -> None:
"""Attempt to fuse operations"""
self._forward_ops = self._fuse_forward_ops(self._forward_ops)
self._backward_ops = self._fuse_backward_ops(self._backward_ops)
def __call__(
self,
input: torch.Tensor, # pylint: disable=redefined-builtin
basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor:
# Initialization before forward pass
for op in self._basic_ops:
op.pre_forward()
# Canonicalize op kwargs
if basic_op_kwargs is None:
basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]
# Flatten list of parameters
params = []
for op in self._basic_ops:
params.extend(op.parameters())
# Fuser forward pass
return _OperationFuserAutogradFunction.apply(
input,
self._forward_ops,
self._backward_ops,
self._basic_ops,
basic_op_kwargs,
*params,
)
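# Illustrative usage sketch (not part of the original file): build a small
# linear+bias pipeline from the basic ops in this commit and run it through
# the fuser. Constructor defaults (e.g. the CUDA device) are assumed.
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
fuser = OperationFuser(
    [
        BasicLinear(in_features=1024, out_features=1024, dtype=torch.bfloat16),
        Bias(size=1024, dtype=torch.bfloat16),
    ],
    fuse_ops=True,  # attempt the forward-pass linear+bias fusion
)
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = fuser(x)        # forward through the (possibly fused) op pipeline
y.sum().backward()  # gradients flow through _OperationFuserAutogradFunction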
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for linear layer."""
from __future__ import annotations
from collections.abc import Callable
from typing import Optional
import torch
from transformer_engine.pytorch.ops.basic import (
AllReduce,
BasicLinear,
Bias,
ReduceScatter,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.ops.op import FusedOperation
class Linear(FusedOperation):
"""Apply linear transformation: :math:`y = x A^T + b`
This is a drop-in replacement for `torch.nn.Linear`.
Parameters
----------
in_features: int
Inner dimension of input tensor
out_features: int
Inner dimension of output tensor
bias: bool, default = `True`
Apply additive bias
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors along
outer dimension (sequence or batch dim) when not distributing
along inner dimension (embedding dim)
rng_state_tracker_function: callable
Function that returns CudaRNGStatesTracker, which is used for
model-parallel weight initialization
accumulate_into_main_grad: bool, default = `False`
Whether to directly accumulate weight gradients into the
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful.
"""
def __init__(
self,
in_features: int,
out_features: int,
*,
bias: bool = True,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sequence_parallel: bool = False,
rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
accumulate_into_main_grad: bool = False,
) -> None:
# Tensor parallel configuration
(
tensor_parallel_mode,
tensor_parallel_group,
tensor_parallel_size,
sequence_parallel,
local_in_features,
local_out_features,
) = BasicLinear._canonicalize_tensor_parallelism(
mode=tensor_parallel_mode,
process_group=tensor_parallel_group,
sequence_parallel=sequence_parallel,
in_features=in_features,
out_features=out_features,
)
# Construct basic ops
ops = []
linear_kwargs = dict(
in_features=in_features,
out_features=out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=tensor_parallel_group,
sequence_parallel=sequence_parallel,
rng_state_tracker_function=rng_state_tracker_function,
accumulate_into_main_grad=accumulate_into_main_grad,
)
bias_kwargs = dict(
size=out_features,
device=device,
dtype=dtype,
tensor_parallel=(tensor_parallel_mode is not None),
tensor_parallel_group=tensor_parallel_group,
)
if tensor_parallel_mode == "row":
# Row TP: GEMM + bias + reduction
linear_kwargs["in_features"] = local_in_features
linear_kwargs["out_features"] = local_out_features
linear_kwargs["tensor_parallel_mode"] = None
linear_kwargs["tensor_parallel_group"] = None
linear_kwargs["sequence_parallel"] = False
bias_kwargs["size"] *= tensor_parallel_size
ops.append(BasicLinear(**linear_kwargs))
if bias:
ops.append(Bias(**bias_kwargs))
if sequence_parallel:
ops.append(ReduceScatter(tensor_parallel_group))
else:
ops.append(AllReduce(tensor_parallel_group))
else:
# Column TP or no TP: (gather + GEMM) + bias
ops.append(BasicLinear(**linear_kwargs))
if bias:
ops.append(Bias(**bias_kwargs))
# Initialize base class
super().__init__(ops)
# Register parameters
self.register_parameter("weight", self.basic_ops[0].weight)
self.register_parameter("bias", self.basic_ops[1].bias if bias else None)