Unverified Commit b1820c44 authored by Tim Moon, committed by GitHub

[PyTorch] Experimental FP8 tensor class (#452)



* Experimental FP8 tensor
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tensor to CI tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments and add tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Default to FP8 usage
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Naming changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix transpose caching
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug transpose caching

Handle case where transpose cache is updated externally.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename FP8GlobalStateManager.with_fp8_parameters
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove Float8Tensor from import API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Avoid caching FP8 transposes if not required
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix import error in FP8 tensor tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix transpose caching and checkpointing bug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve caching and fix distopt case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/float8_tensor.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Remove recursive logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cache reset bug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Store FP8 attributes in dict

Easier for multiple tensors to share, e.g. detached tensors.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure scale_inv is 1D tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure scale_inv is 1D tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fixes and detach recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Set default FP8 data type
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
parent 67051eff
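
For reference (not part of the diff), a minimal usage sketch of the new class, mirroring the tests added in this commit. Note that Float8Tensor is deliberately not re-exported from the top-level transformer_engine.pytorch API, so it is imported from its module:

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Quantize a plain tensor to FP8 (E4M3 by default) and dequantize it back.
x = torch.randn(23, device="cuda")
x_fp8 = Float8Tensor.to_float8(x, scale=torch.full([1], 3.5))
x_roundtrip = x_fp8.from_float8()  # plain tensor in the nominal dtype (float32 here)
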
......@@ -35,6 +35,8 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.fp8_autocast
.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
.. autoapifunction:: transformer_engine.pytorch.checkpoint
.. autoapifunction:: transformer_engine.pytorch.onnx_export
......@@ -12,3 +12,4 @@ PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pyt
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections.abc import Iterable
from typing import Any, Dict, List, Tuple, Union
import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine_extensions as tex
# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125
}
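# Note (comment added for clarity): the "epsilon" values above are the maximum
# relative rounding error (half a ULP) of each format -- E4M3 has 3 mantissa bits
# (2^-4 = 0.0625) and E5M2 has 2 (2^-3 = 0.125) -- and rtol is set to roughly
# twice that to leave headroom for the quantize/dequantize round trip.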
def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
return list(x)
else:
return [x]
# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_constructor(
self,
dims: DimsType = 1,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale_inv: float = 0.375,
dtype: torch.dtype = torch.float32,
) -> None:
"""Call constructor and perform sanity checks"""
dims = _to_list(dims)
tensor = Float8Tensor(
data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.full([1], scale_inv),
dtype=dtype,
)
assert list(tensor.size()) == dims, "Incorrect dims"
assert tensor.dtype == dtype, "Incorrect nominal dtype"
assert tensor.is_cuda, "Incorrect device"
def _test_quantize_dequantize(
self,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
dims: DimsType = 23,
) -> None:
"""Check numerical error when casting to FP8 and back"""
# Initialize random data
x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1
# Cast to FP8 and back
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_fp8 = x_fp8.from_float8().cpu()
# Check results
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
def test_quantize_dequantize_dtypes(
self,
fp8_dtype: tex.DType,
dtype: torch.dtype,
) -> None:
self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype)
@pytest.mark.parametrize("scale", [0.375, 1, 3.5])
def test_quantize_dequantize_scales(self, scale: float) -> None:
self._test_quantize_dequantize(scale=scale)
@pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]])
def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims)
def test_fp8_meta(
self,
dtype: torch.dtype = torch.float32,
dims: DimsType = 23,
) -> None:
"""Construct Float8Tensor using FP8 metadata and perform basic checks"""
# Get FP8 metadata from linear module
fp8_dtype = tex.DType.kFloat8E4M3
recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
module = te.Linear(32, 32)
_ = module(torch.zeros([8, 32], device="cuda"))
fp8_meta = module.fp8_meta
fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
# Make Float8Tensor
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_meta=fp8_meta,
fp8_meta_index=fp8_meta_index,
)
x_ref = x_fp8.from_float8()
assert list(x_fp8.size()) == dims, "Incorrect dims"
assert x_fp8.dtype == dtype, "Incorrect nominal dtype"
assert x_fp8.is_cuda, "Incorrect device"
assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype"
# Change FP8 metadata scale
fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2
fp8_meta[fp8_meta_key].scale_inv.fill_(123)
# Check results
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
with pytest.raises(AssertionError):
# Make sure we are not trivially passing the test
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])
# Check if scaling factor is updated after in-place ops
x_fp8 += 0
fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4
fp8_meta[fp8_meta_key].scale_inv.fill_(321)
assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv"
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
y = x_fp8.detach()
y += 0
assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv"
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
def test_basic_ops(
self,
dims: DimsType = 23,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test basic out-of-place ops"""
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
y_fp8 = Float8Tensor.to_float8(
y_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
y_ref = y_fp8.from_float8()
# Exact operations
torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)
# Operations with numerical error
tols = _tols[fp8_dtype]
torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)
def test_inplace_ops(
self,
dims: DimsType = 23,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test in-place ops"""
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
y_fp8 = Float8Tensor.to_float8(
y_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
y_ref = y_fp8.from_float8()
# In-place operations
tols = _tols[fp8_dtype]
x_fp8 += y_ref
x_ref += y_ref
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()
x_fp8 -= y_fp8
x_ref -= y_fp8
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()
x_fp8 *= 2
x_ref *= 2
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()
# Make sure we are not trivially passing tests
x_ref += 123
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)
@pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]])
@pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)])
def test_transpose(
self,
dims: DimsType,
transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 1,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test transpose"""
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
# Perform transpose
y_fp8 = x_fp8.transpose(*transpose_dims)
y_ref = x_ref.transpose(*transpose_dims)
# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(y_fp8, y_ref, **tols)
# Make sure we are not trivially passing the test
if transpose_dims[0] != transpose_dims[1]:
with pytest.raises(AssertionError):
torch.testing.assert_close(
y_fp8,
x_ref,
**tols,
)
# Check transpose caching
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
......@@ -12,7 +12,7 @@ import torch
import torch.nn as nn
from torch.nn import Parameter
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
......@@ -339,7 +339,7 @@ class TorchGPT(nn.Module):
return x
def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
......@@ -354,24 +354,26 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
with fp8_model_init(enabled=fp8 and fp8_model_params):
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
.cuda()
)
.cuda()
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
......@@ -400,18 +402,19 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
......@@ -426,7 +429,8 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
with fp8_model_init(enabled=fp8 and fp8_model_params):
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
......@@ -441,9 +445,10 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
.cuda()
)
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
......@@ -483,14 +488,15 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
......@@ -871,6 +877,7 @@ def test_linear_accuracy(dtype, bs, model):
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
......@@ -911,6 +918,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
else:
assert_allclose(te_outputs[0], torch_outputs[0], 2e-2)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
......@@ -1110,3 +1118,72 @@ def test_gpt_cuda_graph(dtype, bs, model):
assert_allclose(out, graphed_out, 1e-3)
assert_allclose(params, graphed_params, 1e-3)
assert_allclose(grads, graphed_grads, 1e-3)
def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
reset_rng_states()
FP8GlobalStateManager.reset()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8_model_params):
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
.cuda()
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=True):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_fp8_parameters(dtype, bs, model):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_gpt_fp8_parameters(bs, dtype, config, False)
outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
assert_all_equal(outputs, outputs_fp8_params)
......@@ -147,7 +147,7 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
"""Initialize the FP8 quantization scales in module"""
NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors.
nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
module.fp8_init(num_gemms)
module.init_fp8_metadata(num_gemms)
module.fp8_meta["scaling_fwd"].scale = torch.ones(
nb_total_scales, dtype=torch.float32, device="cuda") / scale
module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
......
......@@ -16,7 +16,7 @@ import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
......@@ -93,7 +93,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
model_in = Test_TE_Export(precision, True)
with te.fp8_autocast(enabled=True):
model_in.fp8_init()
model_in.init_fp8_metadata()
# scaling fwd
model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd
model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd
......
......@@ -13,6 +13,7 @@ from .attention import InferenceParams
from .attention import MultiheadAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .fp8 import fp8_model_init
from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
......
......@@ -83,14 +83,16 @@ def initialize_affine_weight_gpu(
weight: torch.Tensor,
init_method: Callable,
get_rng_state_tracker: Callable,
partition_dim: int,
partition_dim: int = 0,
stride: int = 1,
set_tp_attributes: bool = True,
) -> None:
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
if set_tp_attributes:
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
if get_rng_state_tracker is None:
init_method(weight)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Any, Dict, Optional
import torch
from torch.utils._pytree import tree_map
import transformer_engine_extensions as tex
from .constants import TE_DType
from .fp8 import FP8GlobalStateManager
aten = torch.ops.aten
c10d = torch.ops.c10d
def _make_fp8_attr_property_funcs(name: str) -> Any:
"""Make accessors for an FP8 attribute
We store FP8 attributes in a dictionary so we can share them
between tensors with the same data, e.g. detached tensors. For
convenience, we also expose them as property attributes. This
function creates the accessors for property attributes.
Parameters
----------
name: str
Key in dictionary of FP8 attributes
"""
def get_func(self) -> Any:
return self._fp8_attrs[name]
def set_func(self, value: Any) -> None:
self._fp8_attrs[name] = value
def del_func(self) -> None:
del self._fp8_attrs[name]
return dict(fget=get_func, fset=set_func, fdel=del_func)
class _FromFloat8Func(torch.autograd.Function):
"""Cast from FP8 to other dtype"""
@staticmethod
def forward(
ctx,
tensor: Float8Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if dtype is None:
dtype = tensor.dtype
data = tensor._data.contiguous().view(1,-1).detach()
out = tex.cast_from_fp8(
data,
tensor._scale_inv,
tensor._fp8_dtype,
TE_DType[dtype],
)
out = out.view(tensor.size())
return out
@staticmethod
def backward(ctx, grad):
# Assume that we want gradients in full precision
return grad, None
class _ToFloat8Func(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
@staticmethod
def forward(
ctx,
tensor: torch.Tensor,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
):
# Manually compute scale-inverse if needed
if scale is not None and scale_inv is None:
if isinstance(scale, torch.Tensor):
scale_inv = scale.reciprocal()
else:
scale_inv = 1 / scale
# Extract data from FP8 meta tensors if provided
if fp8_meta is not None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=fp8_meta_forward,
)
if fp8_meta_index is None:
raise ValueError(
"To initialize Float8Tensor with FP8 meta tensors, "
"the FP8 meta tensor index must also be provided"
)
if scale is None:
scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index]
if amax is None:
amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
if scale_inv is None:
scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index]
scale_inv = scale_inv.detach().view(1).clone()
# Check input tensor
tensor = tensor.contiguous().cuda().detach()
if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16):
tensor = tensor.float()
# Check scale
if not isinstance(scale, torch.Tensor):
if scale is None:
scale = 1
scale = torch.full(
[1],
scale,
dtype=torch.float32,
device=tensor.device,
)
if scale.numel() != 1:
raise ValueError(
"Attempted to initialize Float8Tensor with invalid scale tensor"
)
scale = scale.to(device=tensor.device, dtype=torch.float32)
# Check scale-inverse
if scale_inv is None:
scale_inv = scale.reciprocal()
scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32)
# Check amax
if amax is None:
amax = torch.empty_like(scale)
if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32):
raise ValueError(
"Attempted to initialize Float8Tensor with invalid amax tensor"
)
# Cast data to FP8
data = tex.cast_to_fp8(
tensor.view(1,-1),
scale,
amax,
scale_inv,
fp8_dtype,
)
data = data.view(tensor.size())
# Construct FP8 tensor
return Float8Tensor(
data=data,
fp8_meta=fp8_meta,
fp8_meta_forward=fp8_meta_forward,
fp8_meta_index=fp8_meta_index,
fp8_dtype=fp8_dtype,
fp8_scale_inv=scale_inv,
dtype=tensor.dtype,
)
@staticmethod
def backward(ctx, grad):
# Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None
class _IdentityFunc(torch.autograd.Function):
"""Identity function
If constructor keyword-arguments are provided, then construct a
new Float8Tensor using the provided tensor's attributes.
"""
@staticmethod
def forward(
ctx,
tensor: Float8Tensor,
init_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# Return input tensor if constructor kwargs are not provided
ctx.input_dtype = tensor.dtype
if init_kwargs is None:
return tensor
# Construct new tensor if constructor kwargs are provided
default_kwargs = dict(
data=tensor._data,
fp8_meta=tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv,
dtype=tensor.dtype,
)
for key, val in default_kwargs.items():
if key not in init_kwargs:
init_kwargs[key] = val
return Float8Tensor(**init_kwargs)
@staticmethod
def backward(ctx, grad):
return grad.to(ctx.input_dtype), None
class Float8Tensor(torch.Tensor):
"""Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP8. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
data: torch.Tensor
Raw FP8 data in a uint8 tensor
fp8_attrs: dict, optional
FP8 metadata, primarily managed by Float8Tensor. If
provided, all other FP8 configuration is ignored.
fp8_meta: dict, optional
FP8 metadata object, primarily managed by TE modules.
fp8_meta_forward: bool, default = `True`
Whether to access the FP8 metadata for the
forward pass. Ignored if fp8_meta is not
provided.
fp8_meta_index: int, optional
Index to access in FP8 meta tensors. Required if
fp8_meta is provided and otherwise ignored.
fp8_dtype: transformer_engine_extensions.DType, default = tex.DType.kFloat8E4M3
FP8 format.
fp8_scale_inv: torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP8, i.e. the scaling factor that must
be applied when casting from FP8 to higher
precision. Can be inferred from fp8_meta if
provided.
dtype: torch.dtype, default = torch.float32
Nominal tensor datatype.
"""
def __new__(
cls,
*,
data: torch.Tensor,
fp8_attrs: Optional[Dict[str, Any]] = None,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
fp8_scale_inv: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
):
# Check that data buffer is valid
if data.element_size() != 1:
raise ValueError(
"Float8Tensor requires data buffer with 8-bit dtype "
f"(got dtype={data.dtype})"
)
if data.requires_grad:
raise ValueError(
"Float8Tensor requires non-differentiable data buffer"
)
data = data.cuda()
# Initialize tensor object
self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
self._data: torch.Tensor = data
# Initialize dict of class attributes
# Note: We store FP8 attributes in a dictionary so we can
# share them between tensors with the same data, e.g. detached
# tensors.
self._fp8_attrs: dict = {}
if fp8_attrs is not None:
self._fp8_attrs = fp8_attrs
return self
# FP8 meta tensors
if fp8_meta is not None and fp8_meta_index is None:
raise ValueError(
"To initialize Float8Tensor with FP8 meta tensors, "
"the FP8 meta tensor index must also be provided"
)
self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta
self._fp8_meta_forward: bool = fp8_meta_forward
self._fp8_meta_index: Optional[int] = fp8_meta_index
# FP8 dtype
assert (
fp8_dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2)
), f"Unsupported fp8_dtype {fp8_dtype}."
self._fp8_dtype: tex.DType = fp8_dtype
# Cached transpose
self._transpose: Optional[Float8Tensor] = None
# FP8 scale-inverse
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
if self._scale_inv is None and self._fp8_meta is not None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
self._scale_inv = scale_inv.detach().view(1).clone()
if self._scale_inv is None:
raise ValueError(
"Attempted to initialize Float8Tensor without specifying scale-inverse"
)
if not isinstance(self._scale_inv, torch.Tensor):
self._scale_inv = torch.full(
[1],
self._scale_inv,
dtype=torch.float32,
device=self._data.device,
)
if self._scale_inv.numel() != 1:
raise ValueError(
"Attempted to initialize Float8Tensor with invalid scale-inverse tensor"
)
self._scale_inv = self._scale_inv.to(
device=self._data.device,
dtype=torch.float32,
)
return self
@classmethod
def make_like(
cls,
tensor: Float8Tensor,
*,
data: torch.Tensor,
fp8_attrs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Float8Tensor:
"""Use attributes of a Float8Tensor to create another Float8Tensor
See constructor for list of keyword arguments.
"""
default_kwargs = dict(
fp8_meta=tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv,
dtype=tensor.dtype,
)
for key, val in default_kwargs.items():
if key not in kwargs:
kwargs[key] = val
return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs)
def __repr__(self):
return (
"Float8Tensor("
f"fp8_dtype={self._fp8_dtype}, "
f"scale_inv={self._scale_inv.item()}, "
f"data={self.from_float8(dtype=self.dtype)}"
")"
)
def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8Tensor
By default the resulting tensor's dtype is the
Float8Tensor's nominal dtype.
"""
return _FromFloat8Func.apply(self, dtype)
@classmethod
def to_float8(
cls,
tensor: torch.Tensor,
*,
fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True,
fp8_meta_index: Optional[int] = None,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
):
"""Construct Float8Tensor from plain PyTorch tensor"""
return _ToFloat8Func.apply(
tensor,
fp8_meta,
fp8_meta_forward,
fp8_meta_index,
fp8_dtype,
scale,
amax,
scale_inv,
)
def float(self) -> torch.Tensor:
return self.from_float8(dtype=torch.float32)
def bfloat16(self) -> torch.Tensor:
return self.from_float8(dtype=torch.bfloat16)
def half(self) -> torch.Tensor:
return self.from_float8(dtype=torch.float16)
def cpu(self) -> torch.Tensor:
return self.from_float8().cpu()
def clone(self) -> Float8Tensor:
return _IdentityFunc.apply(self, {"data": self._data.detach().clone()})
def expand_as(self, other: torch.Tensor):
if other is self:
# Note: expand_as is hackily used to create dummy autograd nodes
# and access the backward graph (see
# https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026).
# We equally hackily add a dummy function to handle this
# case.
return _IdentityFunc.apply(self)
return super().expand_as(other)
def _transpose_no_cache(self) -> torch.Tensor:
"""
Swap tensor dimensions
For basic 2D matrix transposes, an optimized transpose kernel
is applied and a Float8Tensor is returned.
"""
# Use optimized kernel for basic 2D transpose
# TODO Support differentiation # pylint: disable=fixme
return Float8Tensor.make_like(
self,
data=tex.fp8_transpose(
self._data.contiguous().detach(),
self._fp8_dtype,
),
)
def transpose(
self,
dim0: int = 0,
dim1: int = 1,
*,
update_cache: Optional[bool] = None,
) -> torch.Tensor:
"""
Swap tensor dimensions
For basic 2D matrix transposes, an optimized transpose kernel
is applied and a Float8Tensor is returned.
Parameters
----------
dim0: int, default = 0
The first dimension to be transposed
dim1: int, default = 1
The second dimension to be transposed
update_cache: Optional[bool], default = None
If set to `True`, the result is computed and stored in a cache.
If set to `False`, the result is computed only if the cache is
empty, otherwise the cache is returned. If set to `None`, the
result is not cached. Caching is only supported for basic 2D
transposes and the cache is reset after any in-place operations.
"""
# Handle non-2D transposes
if -self.dim() <= dim0 < 0:
dim0 += self.dim()
if -self.dim() <= dim1 < 0:
dim1 += self.dim()
if self.dim() != 2 or dim0 == dim1:
if update_cache is not None:
raise ValueError(
"Transpose caching is only supported for basic 2D transposes "
f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
)
return super().transpose(dim0, dim1)
# No caching.
if update_cache is None:
return self._transpose_no_cache()
# Update cache.
if update_cache or self._transpose is None:
self._transpose = self._transpose_no_cache()
return self._transpose
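# Illustrative caching behavior for a 2D Float8Tensor `t` (comment added for
# clarity; `t` is hypothetical):
#   t.transpose()                   -> no caching, result computed fresh
#   t.transpose(update_cache=True)  -> result recomputed and stored in the cache
#   t.transpose(update_cache=False) -> cached result reused if present, else computed and stored
# Any in-place modification resets the cache via _reset_caches() below.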
@torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None:
"""Replace FP8 meta tensor scale-inverse with cached value
The FP8 meta tensor scale_inv entry corresponding to this
tensor is replaced with the scale_inv value used to construct
the tensor.
"""
if self._fp8_meta is None:
return
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
scale_inv.view(1).copy_(self._scale_inv.view(1))
def to_dtype(self, dtype: torch.dtype) -> Float8Tensor:
"""Create `Float8Tensor` with given nominal dtype
The new tensor has the same underlying FP8 data.
"""
return Float8Tensor.make_like(
self,
data=self._data,
fp8_attrs=self._fp8_attrs,
dtype=dtype,
)
def _reset_caches(self) -> None:
"""Reset cached values
Should be called after any in-place operation.
"""
self._transpose = None
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# In-place copy op
if func == aten.copy_.default:
# Check tensors
dst = args[0]
src = args[1]
if not isinstance(dst, Float8Tensor):
raise RuntimeError("Expected to copy into Float8Tensor")
if not isinstance(src, torch.Tensor):
raise RuntimeError("Expected to copy from tensor")
if not dst._data.is_contiguous():
raise RuntimeError("Transformer Engine cast kernels require contiguous data")
# Make sure input is in expected format
if isinstance(src, Float8Tensor):
src = src.from_float8()
src = src.expand(dst.size())
src = src.to(
device=dst.device,
memory_format=torch.contiguous_format,
)
# Update scaling factor if FP8 meta tensors are available
if dst._fp8_meta is None:
scale = dst._scale_inv.reciprocal()
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward,
)
scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index]
dst._scale_inv = scale.detach().view(1).reciprocal()
# Cast to FP8
tex.cast_to_fp8_noalloc(
src.view(1,-1),
scale,
dst._data.view(1,-1),
torch.empty_like(dst._scale_inv), # amax
dst._scale_inv,
dst._fp8_dtype,
)
# Nothing to return for in-place ops
dst._reset_caches()
return None
# Slice op
# TODO Consider additional bookkeeping so we invalidate caches # pylint: disable=fixme
# if these slices are modified in-place
if func == aten.slice.Tensor:
tensor = args[0]
data = tensor._data
data_slice = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=data_slice)
# Detach op
if func == aten.detach.default:
# Simply return a new Float8Tensor with the same attrs
return Float8Tensor.make_like(
args[0],
data=args[0]._data,
fp8_attrs=args[0]._fp8_attrs,
)
def maybe_unwrap(t):
if isinstance(t, Float8Tensor):
return t.from_float8()
return t
def maybe_update_inplace(arg, new_arg, schema_arg):
"""Update values of FP8 tensors
Keep the same FP8 scaling factors.
"""
if(
isinstance(arg, Float8Tensor) and
isinstance(new_arg, torch.Tensor) and
hasattr(schema_arg, 'alias_info') and
hasattr(schema_arg.alias_info, 'is_write') and
schema_arg.alias_info.is_write
):
arg.copy_(new_arg)
arg._reset_caches()
# In-place op
if func._schema.is_mutable:
# Cast to higher precision, perform op, and cast values
# back to original FP8 buffers
new_args = tree_map(maybe_unwrap, args)
new_kwargs = tree_map(maybe_unwrap, kwargs)
schema_args = func._schema.arguments
args_len = len(args)
out = super().__torch_dispatch__(func, types, new_args, new_kwargs)
for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
maybe_update_inplace(arg, new_arg, schema_arg)
for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match"
maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
return None
# Default op
# Note: cast to higher precision and perform op
args = tree_map(maybe_unwrap, args)
if kwargs is not None:
kwargs = tree_map(maybe_unwrap, kwargs)
out = super().__torch_dispatch__(func, types, args, kwargs)
return out
def _get_data(self) -> Float8Tensor:
"""Get tensor data property"""
return super().data
def _set_data(self, tensor: torch.Tensor) -> None:
"""Set tensor data property
Cast tensor to FP8 and store in FP8 buffer.
"""
with torch.no_grad():
self.copy_(tensor)
# Cast to FP8 when setting Float8Tensor.data
data = property(_get_data, _set_data)
# Accessors for objects in self._fp8_attrs
# Note: We store FP8 attributes in a dictionary so we can share
# them between tensors with the same data, e.g. detached tensors.
# For convenience, we also expose them as property attributes.
_fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta"))
_fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward"))
_fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index"))
_fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype"))
_transpose = property(**_make_fp8_attr_property_funcs("transpose"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
# Do not force the Float8Tensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl
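# Rough illustration (comment added for clarity; shapes are arbitrary): because the
# `data` property setter routes through `copy_`, assigning a plain tensor to a
# Float8Tensor quantizes it in place into the existing FP8 buffer, refreshing the
# scale from fp8_meta when one is attached:
#
#   w = Float8Tensor.to_float8(torch.randn(16, 16, device="cuda"))
#   w.data = torch.randn(16, 16, device="cuda")  # cast to FP8 in place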
......@@ -17,7 +17,7 @@ from .utils import get_device_compute_capability
from .jit import jit_fuser
__all__ = ["fp8_autocast"]
__all__ = ["fp8_autocast", "fp8_model_init"]
def check_fp8_support() -> Tuple[bool, str]:
......@@ -59,6 +59,7 @@ class FP8GlobalStateManager:
FP8_CALIBRATION = False
FP8_RECIPE = None
FP8_DISTRIBUTED_GROUP = None
FP8_PARAMETERS = False
IS_FIRST_FP8_MODULE = False
FP8_AUTOCAST_COUNTER = 0
FP8_CURRENT_CONTEXT_ID = 0
......@@ -277,6 +278,11 @@ class FP8GlobalStateManager:
"""Is FP8 calibration"""
return cls.FP8_CALIBRATION
@classmethod
def with_fp8_parameters(cls) -> bool:
"""Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS
@classmethod
def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple
......@@ -400,6 +406,11 @@ class FP8GlobalStateManager:
fp8_group: Optional[dist_group_type] = None,
) -> None:
"""Set state and tracking variables for entry into FP8 region."""
if cls.FP8_AUTOCAST_DEPTH == 0:
if callable(cls.amax_forward_global_reduce_func):
cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable
cls.delete_key_from_amax_buffer(forward=True)
cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
......@@ -419,11 +430,6 @@ class FP8GlobalStateManager:
"""Set state and tracking variables for exit from FP8 region."""
cls.FP8_AUTOCAST_DEPTH -= 1
if cls.FP8_AUTOCAST_DEPTH == 0:
if callable(cls.amax_forward_global_reduce_func):
cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable
cls.delete_key_from_amax_buffer(forward=True)
@classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
"""Copy the scaling factors and amaxes for recompute forward phase
......@@ -477,9 +483,45 @@ class FP8GlobalStateManager:
fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
@contextmanager
def fp8_model_init(enabled: bool = True) -> None:
"""
Context manager for FP8 initialization of parameters.
Example usage:
.. code-block:: python
with fp8_model_init(enabled=True):
model = transformer_engine.pytorch.Linear(768, 768)
Parameters
----------
enabled: bool, default = `True`
when enabled, Transformer Engine modules created inside this `fp8_model_init`
region will hold only FP8 copies of their parameters, as opposed to the default
behavior where both higher-precision and FP8 copies are present. Setting this
option to `True` may result in lower memory consumption and is especially
useful for scenarios like:
* full model training using an optimizer with master weights, where the
high-precision copies of the weights are already present in the optimizer.
* inference, where only the FP8 copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
This functionality is *EXPERIMENTAL*.
"""
try:
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
FP8GlobalStateManager.FP8_PARAMETERS = enabled
yield
finally:
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters # pylint: disable=used-before-assignment
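# A minimal end-to-end sketch (comment added for clarity; layer sizes are
# illustrative). Modules created under fp8_model_init store their weights as
# Float8Tensor, and their forward pass is expected to run inside fp8_autocast:
#
#   with fp8_model_init(enabled=True):
#       layer = transformer_engine.pytorch.Linear(768, 768)
#   with fp8_autocast(enabled=True):
#       out = layer(inp)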
@contextmanager
def fp8_autocast(
enabled: bool = False,
enabled: bool = True,
calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
......@@ -508,7 +550,7 @@ def fp8_autocast(
Parameters
----------
enabled: bool, default = `False`
enabled: bool, default = `True`
whether or not to enable fp8
calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
......@@ -523,7 +565,10 @@ def fp8_autocast(
"""
try:
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group)
FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
calibrating=calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group)
yield
finally:
FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment
......
......@@ -36,6 +36,7 @@ from ..cpp_extensions import (
cast_to_fp8,
)
from ..constants import dist_group_type
from ..float8_tensor import Float8Tensor
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
......@@ -451,21 +452,29 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
setattr(
self,
weight_cast_attr,
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
Float8Tensor(
data=torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
setattr(
self,
weight_transpose_attr,
torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
Float8Tensor(
data=torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
......@@ -483,12 +492,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters and not self.fp8_initialized:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors()
if self.fp8 or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if (self.fp8_initialized
......@@ -536,7 +550,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
assert self.tp_group_initialized, "TP group not initialized."
self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
self.init_fp8_metadata(num_gemms=num_gemms)
# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used.
......@@ -765,7 +779,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def get_fp8_weights_empty_tensors(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
) -> List[Float8Tensor]:
"""
Returns empty tensors to be later used to store fp8 version of weights
and their transposes (for the bwd pass) for this batch (or microbatch).
......@@ -781,23 +795,42 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_weight_tensors = []
for shape in self.fp8_weight_shapes:
fp8_weight_tensors.append(
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
Float8Tensor(
data=torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
fp8_weight_tensors.append(
torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
Float8Tensor(
data=torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
return fp8_weight_tensors
def state_dict(self, *args, **kwargs) -> Dict:
"""Get dictionary containing module state"""
state = super().state_dict(*args, **kwargs)
# Convert Float8Tensors to plain tensors
# Note: Float8Tensors don't serialize well, especially if they
# contain references to FP8 metadata.
for key, val in state.items():
if isinstance(val, Float8Tensor):
state[key] = val.from_float8()
return state
@abstractmethod
def forward(self):
......
......@@ -23,7 +23,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
get_default_init_method,
......@@ -43,6 +43,7 @@ from ..jit import no_torch_dynamo
from ._common import _apply_normalization
from ..float8_tensor import Float8Tensor
__all__ = ["LayerNormLinear"]
......@@ -79,10 +80,11 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
normalization: str,
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_ag: bool,
normalization: str,
ub_atomic_gemm_ag: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
......@@ -159,28 +161,43 @@ class _LayerNormLinear(torch.autograd.Function):
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if primary_weights_in_fp8:
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
if is_grad_enabled:
weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if is_grad_enabled:
tex.fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
)
else:
weight_t_fp8 = None
weight_fp8 = tex.cast_to_fp8(
weight_fp8._data = tex.cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
fp8_dtype_forward,
)
weight_t_fp8 = None
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
out, _ = tex.fp8_gemm(
weight_fp8,
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -356,7 +373,7 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm(
weight_t_fp8,
weight_t_fp8._data,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -544,6 +561,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -646,10 +664,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
return_layernorm_output: bool = False,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()
......@@ -666,6 +684,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag
......@@ -719,18 +738,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.layer_norm_bias = None
self.reset_layer_norm_parameters()
self.weight_tensor = torch.empty(
temp_weight = torch.empty(
self.out_features, self.in_features,
device=device, dtype=params_dtype)
initialize_affine_weight_gpu(
self.weight_tensor,
temp_weight,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.weight_tensor = Float8Tensor.to_float8(
temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
else:
self.weight_tensor = temp_weight
if self.use_bias:
self.bias_tensor = torch.empty(
self.out_features,
......@@ -769,10 +800,17 @@ class LayerNormLinear(TransformerEngineBaseModule):
bname = pname + "bias"
slice_end = slice_begin + slice_size
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
# NOTE(future): Figure out a way to support slicing when weights
# are of `Float8Tensor` class
if self.primary_weights_in_fp8:
assert len(parameters_split) == 1, ("Slicing operation is not "
"supported in Float8Tensor "
"class!")
self.register_parameter(wname, Parameter(self.weight_tensor))
else:
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
......@@ -833,7 +871,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None]
if is_first_microbatch is None:
......@@ -877,6 +915,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
bias_tensor = (
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
......@@ -927,10 +967,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.primary_weights_in_fp8,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_ag,
self.normalization,
self.ub_atomic_gemm_ag,
)
out = fwd_fn(*args)
......
......@@ -20,7 +20,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..jit import (
bias_gelu_fused,
bgrad_dgelu_fused,
......@@ -47,6 +47,7 @@ from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization
......@@ -105,14 +106,15 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
normalization: str,
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_rs: bool,
ub_atomic_gemm_rs: bool,
ub_split_ag: bool,
ub_atomic_gemm_ag: bool,
activation: str,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -196,45 +198,68 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias
fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias
if update_fp8_weights:
if primary_weights_in_fp8:
# Weights are already in FP8
fc1_weight.reset_fp8_meta_scale_inv()
fc2_weight.reset_fp8_meta_scale_inv()
fc1_weight_fp8 = fc1_weight
fc2_weight_fp8 = fc2_weight
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
if is_grad_enabled:
fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch)
fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor(
data=fc1_weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
fc2_weight_fp8 = Float8Tensor(
data=fc2_weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
)
if is_grad_enabled:
# Fused cast-transpose kernels
tex.fp8_cast_transpose_fused(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=fc1_weight_fp8,
transpose_out=fc1_weight_t_fp8,
cast_out=fc1_weight_fp8._data,
transpose_out=fc1_weight_t_fp8._data,
)
tex.fp8_cast_transpose_fused(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
cast_out=fc2_weight_fp8,
transpose_out=fc2_weight_t_fp8,
cast_out=fc2_weight_fp8._data,
transpose_out=fc2_weight_t_fp8._data,
)
else:
fc1_weight_t_fp8 = None
fc1_weight_fp8 = tex.cast_to_fp8(
fc1_weight_fp8._data = tex.cast_to_fp8(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
fc2_weight_t_fp8 = None
fc2_weight_fp8 = tex.cast_to_fp8(
fc1_weight_t_fp8 = None
fc2_weight_fp8._data = tex.cast_to_fp8(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
)
fc2_weight_t_fp8 = None
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
fc1_out, _ = tex.fp8_gemm(
fc1_weight_fp8,
fc1_weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -283,7 +308,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = tex.fp8_gemm(
fc2_weight_fp8,
fc2_weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
......@@ -530,7 +555,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
# FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8,
fc2_weight_t_fp8._data,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
......@@ -645,7 +670,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
# FC1 DGRAD: Unconditional
_ = tex.fp8_gemm(
fc1_weight_t_fp8,
fc1_weight_t_fp8._data,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -908,6 +933,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -1020,12 +1046,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False,
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()
......@@ -1043,6 +1069,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.activation == 'gelu')
self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_rs = ub_split_rs
......@@ -1102,19 +1129,30 @@ class LayerNormMLP(TransformerEngineBaseModule):
else:
fc1_output_features = self.size_per_partition
# FC1 init
self.fc1_weight = Parameter(
torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype)
)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
fc1_temp_weight = torch.empty(
fc1_output_features, hidden_size, device=device, dtype=params_dtype)
initialize_affine_weight_gpu(
self.fc1_weight,
fc1_temp_weight,
init_method,
get_rng_state_tracker,
partition_dim=0,
stride=1,
set_tp_attributes=False,
)
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=2)
self.fp8_meta["update_amax_and_scale_fwd"] = True
fc1_temp_weight = Float8Tensor.to_float8(
fc1_temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
self.fc1_weight = Parameter(fc1_temp_weight)
set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
if self.use_bias:
self.fc1_bias = Parameter(
torch.empty(fc1_output_features, device=device, dtype=params_dtype)
......@@ -1127,19 +1165,27 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fc1_bias.zero_()
# FC2 init
self.fc2_weight = Parameter(
torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
fc2_temp_weight = torch.empty(
hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
initialize_affine_weight_gpu(
self.fc2_weight,
fc2_temp_weight,
output_layer_init_method,
get_rng_state_tracker,
partition_dim=1,
stride=1,
set_tp_attributes=False,
)
if self.primary_weights_in_fp8:
fc2_temp_weight = Float8Tensor.to_float8(
fc2_temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
)
self.fc2_weight = Parameter(fc2_temp_weight)
set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
if self.use_bias:
self.fc2_bias = Parameter(
torch.empty(hidden_size, device=device, dtype=params_dtype)
......@@ -1192,7 +1238,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None, None, None]
if is_first_microbatch is None:
......@@ -1235,6 +1281,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \
self.get_fp8_weights_scratchpad(
......@@ -1279,14 +1327,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.normalization,
self.primary_weights_in_fp8,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_rs,
self.ub_atomic_gemm_rs,
self.ub_split_ag,
self.ub_atomic_gemm_ag,
self.activation,
self.normalization,
)
out = fwd_fn(*args)
......
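The weight-casting pattern used in the hunks above can be summarized outside the module code. A hedged sketch of the two `Float8Tensor` entry points this diff relies on: `Float8Tensor.to_float8(...)` casts a high-precision master weight at construction time, while the plain constructor wraps already-allocated FP8 bytes so fused cast/transpose kernels can write directly into `._data`. The module and shapes below are illustrative; the `fp8_meta` handle is assumed to have been populated by `init_fp8_metadata()`, as in the diff's `__init__` path.

```python
# Illustrative sketch of the Float8Tensor construction patterns in this diff.
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Build a module whose fp8_meta is initialized (assumption: fp8_model_init()
# triggers init_fp8_metadata() during construction, as in the diff).
with te.fp8_model_init():
    linear = te.Linear(1024, 1024)

hp_weight = torch.empty(1024, 1024, device="cuda")  # high-precision weight

# Pattern 1: cast an existing tensor into a Float8Tensor (module init path).
fp8_weight = Float8Tensor.to_float8(
    hp_weight,
    fp8_meta=linear.fp8_meta,
    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)

# Pattern 2: wrap existing FP8 bytes so kernels can write into `._data`
# (the path taken when update_fp8_weights re-wraps the scratchpad buffers).
fp8_view = Float8Tensor(
    data=fp8_weight._data,
    fp8_meta=linear.fp8_meta,
    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
```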
......@@ -20,7 +20,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
get_default_init_method,
......@@ -45,6 +45,8 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor
__all__ = ["Linear"]
......@@ -57,9 +59,9 @@ class _Linear(torch.autograd.Function):
@staticmethod
def forward(
ctx,
weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None],
weight: Union[Float8Tensor, torch.Tensor],
weight_fp8: Union[Float8Tensor, None],
weight_t_fp8: Union[Float8Tensor, None],
inp: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
......@@ -75,6 +77,7 @@ class _Linear(torch.autograd.Function):
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
primary_weights_in_fp8: bool,
ub_split_rs: bool,
ub_split_ag: bool,
ub_atomic_gemm_rs: bool,
......@@ -141,24 +144,38 @@ class _Linear(torch.autograd.Function):
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if primary_weights_in_fp8:
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
if is_grad_enabled:
weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if is_grad_enabled:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight_fp8._data = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
weight_t_fp8 = None
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
None, None, None, activation_dtype)
......@@ -184,7 +201,7 @@ class _Linear(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = fp8_gemm(
weight_fp8,
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -245,6 +262,9 @@ class _Linear(torch.autograd.Function):
if is_grad_enabled:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
if fp8:
assert hasattr(weight_t_fp8, "_data"), \
"_data attr doesn't exist (before save for bwd)"
ctx.save_for_backward(
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
inputmat_t if weight.requires_grad and fp8_wgrad else None,
......@@ -294,6 +314,9 @@ class _Linear(torch.autograd.Function):
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
if weight_t_fp8 is not None:
assert hasattr(weight_t_fp8, "_data"), \
"_data attr doesn't exist (after restore in bwd)"
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
......@@ -349,7 +372,7 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
if ctx.fp8:
dgrad, _ = fp8_gemm(
weight_t_fp8,
weight_t_fp8._data,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -470,6 +493,7 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
)
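The backward pass above consumes the FP8 weight transpose that forward saved. A hedged sketch of the caching contract the diff relies on: `Float8Tensor.transpose(update_cache=...)` recomputes and caches the transpose when the flag is truthy (the first microbatch of an optimizer step, when weights have just changed) and returns the cached copy otherwise. The loop below is illustrative only; `weight_fp8` is assumed to be a Float8Tensor parameter like the one constructed in the sketch above.

```python
# Illustrative sketch of the transpose-caching pattern used in forward.
# Assumption: `weight_fp8` is a Float8Tensor parameter (see previous sketch).
num_microbatches = 4  # illustrative
for microbatch_idx in range(num_microbatches):
    is_first_microbatch = microbatch_idx == 0
    # First microbatch: recompute the FP8 transpose and refresh the cache.
    # Later microbatches: reuse the cached transpose, since weights are
    # frozen until the optimizer step.
    weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
```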
......@@ -554,9 +578,9 @@ class Linear(TransformerEngineBaseModule):
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
device: Union[torch.device, str] = "cuda",
ub_split_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
) -> None:
......@@ -570,6 +594,7 @@ class Linear(TransformerEngineBaseModule):
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.parameters_split = parameters_split
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
......@@ -609,18 +634,31 @@ class Linear(TransformerEngineBaseModule):
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.weight_tensor = torch.empty(
temp_weight = torch.empty(
self.out_features, self.in_features,
device=device, dtype=params_dtype)
# TODO(ksivaman): This functionality works with FP8 outside TE.
initialize_affine_weight_gpu(
self.weight_tensor,
temp_weight,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.weight_tensor = Float8Tensor.to_float8(
temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
else:
self.weight_tensor = temp_weight
if self.use_bias:
self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
else:
......@@ -657,9 +695,17 @@ class Linear(TransformerEngineBaseModule):
slice_end = slice_begin + slice_size
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
# TODO(ksivaman): Add indexing op to torch dispatcher for float8
if self.primary_weights_in_fp8:
assert len(parameters_split) == 1, ("Slicing operation is not "
"supported in Float8Tensor "
"class!")
self.register_parameter(wname, Parameter(self.weight_tensor))
else:
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
......@@ -697,13 +743,13 @@ class Linear(TransformerEngineBaseModule):
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
) -> List[Float8Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None]
if is_first_microbatch is None:
......@@ -747,6 +793,8 @@ class Linear(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
bias_tensor = (
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
......@@ -790,6 +838,7 @@ class Linear(TransformerEngineBaseModule):
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
self.primary_weights_in_fp8,
self.ub_split_rs,
self.ub_split_ag,
self.ub_atomic_gemm_rs,
......
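At the user level, the scratchpad and caching logic above is driven by the `is_first_microbatch` hint passed to the module's forward (visible in `prepare_forward(inp, is_first_microbatch)` and the `get_fp8_weights_scratchpad` docstrings). A hedged sketch of that pattern, with illustrative shapes and the loss/optimizer steps elided:

```python
# Hedged user-level sketch: refresh FP8 weight casts/transposes only on the
# first microbatch of an optimizer step, reuse the cached copies afterwards.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)
microbatches = [torch.randn(16, 1024, device="cuda") for _ in range(4)]

with te.fp8_autocast():
    for i, mb in enumerate(microbatches):
        out = layer(mb, is_first_microbatch=(i == 0))
        # ... per-microbatch loss and backward; optimizer step after the loop
```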