"tests/vscode:/vscode.git/clone" did not exist on "c664b0e6838644c22839b6e9c88e61b4e9a540f6"
Unverified commit b1820c44, authored by Tim Moon and committed by GitHub

[PyTorch] Experimental FP8 tensor class (#452)



* Experimental FP8 tensor
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tensor to CI test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments and tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Default to FP8 usage
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Naming changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix transpose caching
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug transpose caching

Handle case where transpose cache is updated externally.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename FP8GlobalStateManager.with_fp8_parameters
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove Float8Tensor from import API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Avoid caching FP8 transposes if not required
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix import error in FP8 tensor tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix transpose caching and checkpointing bug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve caching and fix distopt case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/float8_tensor.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Remove recursive logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cache reset bug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Store FP8 attributes in dict

Easier for multiple tensors to share, e.g. detached tensors.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure scale_inv is 1D tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure scale_inv is 1D tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fixes and detach recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Set default FP8 data type
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
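
For context, the new unit tests added in this commit exercise the Float8Tensor API roughly as follows (a minimal sketch mirroring tests/pytorch/test_float8tensor.py below; argument names and tolerances are taken from the test code and may change while the class is experimental):

    import torch
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    # Quantize a float32 tensor to FP8 (E4M3) with an explicit scale, then dequantize.
    x_ref = 2 * torch.rand([23], dtype=torch.float32, device="cpu") - 1
    x_fp8 = Float8Tensor.to_float8(
        x_ref,
        fp8_dtype=tex.DType.kFloat8E4M3,
        scale=torch.full([1], 3.5),
    )
    x_rt = x_fp8.from_float8().cpu()  # round-trip back to the nominal dtype
    torch.testing.assert_close(x_rt, x_ref, rtol=0.125, atol=0.0675)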
parent 67051eff
......@@ -35,6 +35,8 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.fp8_autocast
.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
.. autoapifunction:: transformer_engine.pytorch.checkpoint
.. autoapifunction:: transformer_engine.pytorch.onnx_export
......@@ -12,3 +12,4 @@ PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pyt
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections.abc import Iterable
from typing import Any, Dict, List, Tuple, Union
import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine_extensions as tex
# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125
}
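# Note (added for clarity): these tolerances track the mantissa width of each
# FP8 format. E4M3 has 3 mantissa bits, so round-to-nearest introduces a
# relative error of at most 2**-4 = 0.0625; E5M2 has 2 mantissa bits, giving
# 2**-3 = 0.125. The rtol values above are twice those bounds, leaving some
# headroom for the scaling step.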
def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
return list(x)
else:
return [x]
# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_constructor(
self,
dims: DimsType = 1,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale_inv: float = 0.375,
dtype: torch.dtype = torch.float32,
) -> None:
"""Call constructor and perform sanity checks"""
dims = _to_list(dims)
tensor = Float8Tensor(
data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
fp8_dtype=fp8_dtype,
fp8_scale_inv=torch.full([1], scale_inv),
dtype=dtype,
)
assert list(tensor.size()) == dims, "Incorrect dims"
assert tensor.dtype == dtype, "Incorrect nominal dtype"
assert tensor.is_cuda, "Incorrect device"
def _test_quantize_dequantize(
self,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
dims: DimsType = 23,
) -> None:
"""Check numerical error when casting to FP8 and back"""
# Initialize random data
x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1
# Cast to FP8 and back
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_fp8 = x_fp8.from_float8().cpu()
# Check results
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
def test_quantize_dequantize_dtypes(
self,
fp8_dtype: tex.DType,
dtype: torch.dtype,
) -> None:
self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype)
@pytest.mark.parametrize("scale", [0.375, 1, 3.5])
def test_quantize_dequantize_scales(self, scale: float) -> None:
self._test_quantize_dequantize(scale=scale)
@pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]])
def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims)
def test_fp8_meta(
self,
dtype: torch.dtype = torch.float32,
dims: DimsType = 23,
) -> None:
"""Construct Float8Tensor using FP8 metadata and perform basic checks"""
# Get FP8 metadata from linear module
fp8_dtype = tex.DType.kFloat8E4M3
recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
module = te.Linear(32, 32)
_ = module(torch.zeros([8, 32], device="cuda"))
fp8_meta = module.fp8_meta
fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
# Make Float8Tensor
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_meta=fp8_meta,
fp8_meta_index=fp8_meta_index,
)
x_ref = x_fp8.from_float8()
assert list(x_fp8.size()) == dims, "Incorrect dims"
assert x_fp8.dtype == dtype, "Incorrect nominal dtype"
assert x_fp8.is_cuda, "Incorrect device"
assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype"
# Change FP8 metadata scale
fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2
fp8_meta[fp8_meta_key].scale_inv.fill_(123)
# Check results
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
with pytest.raises(AssertionError):
# Make sure we are not trivially passing the test
torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])
# Check if scaling factor is updated after in-place ops
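# Note (added for clarity): an in-place op is expected to re-quantize the data,
# so the tensor's _scale_inv should become the reciprocal of the fp8_meta scale
# in effect at the time of the op (1/2 after scale=2 above, then 1/4 after
# scale=4 via the detached view, which shares the FP8 attributes), regardless
# of the value that scale_inv in fp8_meta was overwritten with.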
x_fp8 += 0
fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4
fp8_meta[fp8_meta_key].scale_inv.fill_(321)
assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv"
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
y = x_fp8.detach()
y += 0
assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv"
torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
def test_basic_ops(
self,
dims: DimsType = 23,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test basic out-of-place ops"""
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
y_fp8 = Float8Tensor.to_float8(
y_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
y_ref = y_fp8.from_float8()
# Exact operations
torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)
# Operations with numerical error
tols = _tols[fp8_dtype]
torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)
# Make sure we are not trivially passing tests
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)
def test_inplace_ops(
self,
dims: DimsType = 23,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test in-place ops"""
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
y_fp8 = Float8Tensor.to_float8(
y_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
y_ref = y_fp8.from_float8()
# In-place operations
tols = _tols[fp8_dtype]
x_fp8 += y_ref
x_ref += y_ref
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()
x_fp8 -= y_fp8
x_ref -= y_fp8
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()
x_fp8 *= 2
x_ref *= 2
torch.testing.assert_close(x_fp8, x_ref, **tols)
x_ref = x_fp8.from_float8()
# Make sure we are not trivially passing tests
x_ref += 123
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)
@pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]])
@pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)])
def test_transpose(
self,
dims: DimsType,
transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 1,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test transpose"""
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
# Perform transpose
y_fp8 = x_fp8.transpose(*transpose_dims)
y_ref = x_ref.transpose(*transpose_dims)
# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(y_fp8, y_ref, **tols)
# Make sure we are not trivially passing the test
if transpose_dims[0] != transpose_dims[1]:
with pytest.raises(AssertionError):
torch.testing.assert_close(
y_fp8,
x_ref,
**tols,
)
# Check transpose caching
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
......@@ -12,7 +12,7 @@ import torch
import torch.nn as nn
from torch.nn import Parameter
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
......@@ -339,7 +339,7 @@ class TorchGPT(nn.Module):
return x
def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
......@@ -354,24 +354,26 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
with fp8_model_init(enabled=fp8 and fp8_model_params):
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
.cuda()
)
.cuda()
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
......@@ -400,18 +402,19 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
......@@ -426,7 +429,8 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = (
with fp8_model_init(enabled=fp8 and fp8_model_params):
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
......@@ -441,9 +445,10 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
.cuda()
)
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
......@@ -483,14 +488,15 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("fp8", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
assert_all_equal(outputs, outputs_recompute)
......@@ -871,6 +877,7 @@ def test_linear_accuracy(dtype, bs, model):
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
......@@ -911,6 +918,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
else:
assert_allclose(te_outputs[0], torch_outputs[0], 2e-2)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
......@@ -1110,3 +1118,72 @@ def test_gpt_cuda_graph(dtype, bs, model):
assert_allclose(out, graphed_out, 1e-3)
assert_allclose(params, graphed_params, 1e-3)
assert_allclose(grads, graphed_grads, 1e-3)
def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
reset_rng_states()
FP8GlobalStateManager.reset()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8_model_params):
block = (
TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
.cuda()
)
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
with fp8_autocast(enabled=True):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
outputs = [te_out, te_inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_fp8_parameters(dtype, bs, model):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
outputs = _test_gpt_fp8_parameters(bs, dtype, config, False)
outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
assert_all_equal(outputs, outputs_fp8_params)
......@@ -147,7 +147,7 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
"""Initialize the FP8 quantization scales in module"""
NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors.
nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
module.fp8_init(num_gemms)
module.init_fp8_metadata(num_gemms)
module.fp8_meta["scaling_fwd"].scale = torch.ones(
nb_total_scales, dtype=torch.float32, device="cuda") / scale
module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
......
......@@ -16,7 +16,7 @@ import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
......@@ -93,7 +93,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
model_in = Test_TE_Export(precision, True)
with te.fp8_autocast(enabled=True):
model_in.fp8_init()
model_in.init_fp8_metadata()
# scaling fwd
model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd
model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd
......
......@@ -13,6 +13,7 @@ from .attention import InferenceParams
from .attention import MultiheadAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .fp8 import fp8_model_init
from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
......
......@@ -83,14 +83,16 @@ def initialize_affine_weight_gpu(
weight: torch.Tensor,
init_method: Callable,
get_rng_state_tracker: Callable,
partition_dim: int,
partition_dim: int = 0,
stride: int = 1,
set_tp_attributes: bool = True,
) -> None:
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
if set_tp_attributes:
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
if get_rng_state_tracker is None:
init_method(weight)
......
This diff is collapsed.
......@@ -17,7 +17,7 @@ from .utils import get_device_compute_capability
from .jit import jit_fuser
__all__ = ["fp8_autocast"]
__all__ = ["fp8_autocast", "fp8_model_init"]
def check_fp8_support() -> Tuple[bool, str]:
......@@ -59,6 +59,7 @@ class FP8GlobalStateManager:
FP8_CALIBRATION = False
FP8_RECIPE = None
FP8_DISTRIBUTED_GROUP = None
FP8_PARAMETERS = False
IS_FIRST_FP8_MODULE = False
FP8_AUTOCAST_COUNTER = 0
FP8_CURRENT_CONTEXT_ID = 0
......@@ -277,6 +278,11 @@ class FP8GlobalStateManager:
"""Is FP8 calibration"""
return cls.FP8_CALIBRATION
@classmethod
def with_fp8_parameters(cls) -> bool:
"""Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS
@classmethod
def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple
......@@ -400,6 +406,11 @@ class FP8GlobalStateManager:
fp8_group: Optional[dist_group_type] = None,
) -> None:
"""Set state and tracking variables for entry into FP8 region."""
if cls.FP8_AUTOCAST_DEPTH == 0:
if callable(cls.amax_forward_global_reduce_func):
cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable
cls.delete_key_from_amax_buffer(forward=True)
cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
......@@ -419,11 +430,6 @@ class FP8GlobalStateManager:
"""Set state and tracking variables for exit from FP8 region."""
cls.FP8_AUTOCAST_DEPTH -= 1
if cls.FP8_AUTOCAST_DEPTH == 0:
if callable(cls.amax_forward_global_reduce_func):
cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable
cls.delete_key_from_amax_buffer(forward=True)
@classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
"""Copy the scaling factors and amaxes for recompute forward phase
......@@ -477,9 +483,45 @@ class FP8GlobalStateManager:
fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
@contextmanager
def fp8_model_init(enabled: bool = True) -> None:
"""
Context manager for FP8 initialization of parameters.
Example usage:
.. code-block:: python
with fp8_model_init(enabled=True):
model = transformer_engine.pytorch.Linear(768, 768)
Parameters
----------
enabled: bool, default = `True`
when enabled, Transformer Engine modules created inside this `fp8_model_init`
region will hold only FP8 copies of its parameters, as opposed to the default
behavior where both higher precision and FP8 copies are present. Setting this
option to `True` may result in lower memory consumption and is especially
useful for scenarios like:
* full model training using optimizer with master weights, where the high
precision copies of weights are already present in the optimizer.
* inference, where only the FP8 copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
This functionality is *EXPERIMENTAL*.
"""
try:
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
FP8GlobalStateManager.FP8_PARAMETERS = enabled
yield
finally:
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters # pylint: disable=used-before-assignment
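# Example usage (added for clarity; this mirrors the pattern used by the tests
# in this commit, e.g. _test_gpt_fp8_parameters): parameters are created in FP8
# inside fp8_model_init(), and the forward pass must then run inside an
# fp8_autocast() region:
#
#     with fp8_model_init(enabled=True):
#         model = transformer_engine.pytorch.Linear(768, 768)
#     with fp8_autocast(enabled=True):
#         out = model(inp)  # `inp` is any appropriately shaped input tensor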
@contextmanager
def fp8_autocast(
enabled: bool = False,
enabled: bool = True,
calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
......@@ -508,7 +550,7 @@ def fp8_autocast(
Parameters
----------
enabled: bool, default = `False`
enabled: bool, default = `True`
whether or not to enable fp8
calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
......@@ -523,7 +565,10 @@ def fp8_autocast(
"""
try:
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group)
FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
calibrating=calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group)
yield
finally:
FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment
......
......@@ -36,6 +36,7 @@ from ..cpp_extensions import (
cast_to_fp8,
)
from ..constants import dist_group_type
from ..float8_tensor import Float8Tensor
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
......@@ -451,21 +452,29 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
setattr(
self,
weight_cast_attr,
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
Float8Tensor(
data=torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
setattr(
self,
weight_transpose_attr,
torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
Float8Tensor(
data=torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
......@@ -483,12 +492,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters and not self.fp8_initialized:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors()
if self.fp8 or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if (self.fp8_initialized
......@@ -536,7 +550,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
assert self.tp_group_initialized, "TP group not initialized."
self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
self.init_fp8_metadata(num_gemms=num_gemms)
# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used.
......@@ -765,7 +779,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def get_fp8_weights_empty_tensors(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
) -> List[Float8Tensor]:
"""
Returns empty tensors to be later used to store fp8 version of weights
and their transposes (for the bwd pass) for this batch (or microbatch).
......@@ -781,23 +795,42 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_weight_tensors = []
for shape in self.fp8_weight_shapes:
fp8_weight_tensors.append(
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
Float8Tensor(
data=torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
fp8_weight_tensors.append(
torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
Float8Tensor(
data=torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)
return fp8_weight_tensors
def state_dict(self, *args, **kwargs) -> Dict:
"""Get dictionary containing module state"""
state = super().state_dict(*args, **kwargs)
# Convert Float8Tensors to plain tensors
# Note: Float8Tensors don't serialize well, especially if they
# contain references to FP8 metadata.
for key, val in state.items():
if isinstance(val, Float8Tensor):
state[key] = val.from_float8()
return state
@abstractmethod
def forward(self):
......
......@@ -23,7 +23,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
get_default_init_method,
......@@ -43,6 +43,7 @@ from ..jit import no_torch_dynamo
from ._common import _apply_normalization
from ..float8_tensor import Float8Tensor
__all__ = ["LayerNormLinear"]
......@@ -79,10 +80,11 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
normalization: str,
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_ag: bool,
normalization: str,
ub_atomic_gemm_ag: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
......@@ -159,28 +161,43 @@ class _LayerNormLinear(torch.autograd.Function):
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if primary_weights_in_fp8:
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
if is_grad_enabled:
weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if is_grad_enabled:
tex.fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
)
else:
weight_t_fp8 = None
weight_fp8 = tex.cast_to_fp8(
weight_fp8._data = tex.cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
fp8_dtype_forward,
)
weight_t_fp8 = None
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
out, _ = tex.fp8_gemm(
weight_fp8,
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -356,7 +373,7 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm(
weight_t_fp8,
weight_t_fp8._data,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -544,6 +561,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -646,10 +664,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
return_layernorm_output: bool = False,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()
......@@ -666,6 +684,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag
......@@ -719,18 +738,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.layer_norm_bias = None
self.reset_layer_norm_parameters()
self.weight_tensor = torch.empty(
temp_weight = torch.empty(
self.out_features, self.in_features,
device=device, dtype=params_dtype)
initialize_affine_weight_gpu(
self.weight_tensor,
temp_weight,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.weight_tensor = Float8Tensor.to_float8(
temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
else:
self.weight_tensor = temp_weight
if self.use_bias:
self.bias_tensor = torch.empty(
self.out_features,
......@@ -769,10 +800,17 @@ class LayerNormLinear(TransformerEngineBaseModule):
bname = pname + "bias"
slice_end = slice_begin + slice_size
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
# NOTE(future): Figure out a way to support slicing when weights
# are of `Float8Tensor` class
if self.primary_weights_in_fp8:
assert len(parameters_split) == 1, ("Slicing operation is not "
"supported in Float8Tensor "
"class!")
self.register_parameter(wname, Parameter(self.weight_tensor))
else:
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
......@@ -833,7 +871,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None]
if is_first_microbatch is None:
......@@ -877,6 +915,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
bias_tensor = (
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
......@@ -927,10 +967,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.primary_weights_in_fp8,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_ag,
self.normalization,
self.ub_atomic_gemm_ag,
)
out = fwd_fn(*args)
......
......@@ -20,7 +20,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..jit import (
bias_gelu_fused,
bgrad_dgelu_fused,
......@@ -47,6 +47,7 @@ from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization
......@@ -105,14 +106,15 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
normalization: str,
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_rs: bool,
ub_atomic_gemm_rs: bool,
ub_split_ag: bool,
ub_atomic_gemm_ag: bool,
activation: str,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -196,45 +198,68 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias
fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias
if update_fp8_weights:
if primary_weights_in_fp8:
# Weights are already in FP8
fc1_weight.reset_fp8_meta_scale_inv()
fc2_weight.reset_fp8_meta_scale_inv()
fc1_weight_fp8 = fc1_weight
fc2_weight_fp8 = fc2_weight
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
if is_grad_enabled:
fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch)
fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor(
data=fc1_weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
fc2_weight_fp8 = Float8Tensor(
data=fc2_weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
)
if is_grad_enabled:
# Fused cast-transpose kernels
tex.fp8_cast_transpose_fused(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=fc1_weight_fp8,
transpose_out=fc1_weight_t_fp8,
cast_out=fc1_weight_fp8._data,
transpose_out=fc1_weight_t_fp8._data,
)
tex.fp8_cast_transpose_fused(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
cast_out=fc2_weight_fp8,
transpose_out=fc2_weight_t_fp8,
cast_out=fc2_weight_fp8._data,
transpose_out=fc2_weight_t_fp8._data,
)
else:
fc1_weight_t_fp8 = None
fc1_weight_fp8 = tex.cast_to_fp8(
fc1_weight_fp8._data = tex.cast_to_fp8(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
fc2_weight_t_fp8 = None
fc2_weight_fp8 = tex.cast_to_fp8(
fc1_weight_t_fp8 = None
fc2_weight_fp8._data = tex.cast_to_fp8(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
)
fc2_weight_t_fp8 = None
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
fc1_out, _ = tex.fp8_gemm(
fc1_weight_fp8,
fc1_weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -283,7 +308,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = tex.fp8_gemm(
fc2_weight_fp8,
fc2_weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
......@@ -530,7 +555,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
# FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8,
fc2_weight_t_fp8._data,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
......@@ -645,7 +670,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
# FC1 DGRAD: Unconditional
_ = tex.fp8_gemm(
fc1_weight_t_fp8,
fc1_weight_t_fp8._data,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -908,6 +933,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -1020,12 +1046,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False,
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_ag: bool = False,
) -> None:
super().__init__()
......@@ -1043,6 +1069,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.activation == 'gelu')
self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_rs = ub_split_rs
......@@ -1102,19 +1129,30 @@ class LayerNormMLP(TransformerEngineBaseModule):
else:
fc1_output_features = self.size_per_partition
# FC1 init
self.fc1_weight = Parameter(
torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype)
)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
fc1_temp_weight = torch.empty(
fc1_output_features, hidden_size, device=device, dtype=params_dtype)
initialize_affine_weight_gpu(
self.fc1_weight,
fc1_temp_weight,
init_method,
get_rng_state_tracker,
partition_dim=0,
stride=1,
set_tp_attributes=False,
)
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=2)
self.fp8_meta["update_amax_and_scale_fwd"] = True
fc1_temp_weight = Float8Tensor.to_float8(
fc1_temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
self.fc1_weight = Parameter(fc1_temp_weight)
set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
if self.use_bias:
self.fc1_bias = Parameter(
torch.empty(fc1_output_features, device=device, dtype=params_dtype)
......@@ -1127,19 +1165,27 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fc1_bias.zero_()
# FC2 init
self.fc2_weight = Parameter(
torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
fc2_temp_weight = torch.empty(
hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
initialize_affine_weight_gpu(
self.fc2_weight,
fc2_temp_weight,
output_layer_init_method,
get_rng_state_tracker,
partition_dim=1,
stride=1,
set_tp_attributes=False,
)
if self.primary_weights_in_fp8:
fc2_temp_weight = Float8Tensor.to_float8(
fc2_temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
)
self.fc2_weight = Parameter(fc2_temp_weight)
set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
if self.use_bias:
self.fc2_bias = Parameter(
torch.empty(hidden_size, device=device, dtype=params_dtype)
......@@ -1192,7 +1238,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None, None, None]
if is_first_microbatch is None:
......@@ -1235,6 +1281,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \
self.get_fp8_weights_scratchpad(
......@@ -1279,14 +1327,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.normalization,
self.primary_weights_in_fp8,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_rs,
self.ub_atomic_gemm_rs,
self.ub_split_ag,
self.ub_atomic_gemm_ag,
self.activation,
self.normalization,
)
out = fwd_fn(*args)
......
......@@ -20,7 +20,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
get_default_init_method,
......@@ -45,6 +45,8 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor
__all__ = ["Linear"]
......@@ -57,9 +59,9 @@ class _Linear(torch.autograd.Function):
@staticmethod
def forward(
ctx,
weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None],
weight: Union[Float8Tensor, torch.Tensor],
weight_fp8: Union[Float8Tensor, None],
weight_t_fp8: Union[Float8Tensor, None],
inp: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
......@@ -75,6 +77,7 @@ class _Linear(torch.autograd.Function):
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
primary_weights_in_fp8: bool,
ub_split_rs: bool,
ub_split_ag: bool,
ub_atomic_gemm_rs: bool,
......@@ -141,24 +144,38 @@ class _Linear(torch.autograd.Function):
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if primary_weights_in_fp8:
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
if is_grad_enabled:
weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
if is_grad_enabled:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight_fp8._data = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
weight_t_fp8 = None
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
None, None, None, activation_dtype)
......@@ -184,7 +201,7 @@ class _Linear(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = fp8_gemm(
weight_fp8,
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -245,6 +262,9 @@ class _Linear(torch.autograd.Function):
if is_grad_enabled:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
if fp8:
assert hasattr(weight_t_fp8, "_data"), \
"_data attr doesn't exist (before save for bwd)"
ctx.save_for_backward(
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
inputmat_t if weight.requires_grad and fp8_wgrad else None,
......@@ -294,6 +314,9 @@ class _Linear(torch.autograd.Function):
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
if weight_t_fp8 is not None:
assert hasattr(weight_t_fp8, "_data"), \
"_data attr doesn't exist (after restore in bwd)"
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
......@@ -349,7 +372,7 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
if ctx.fp8:
dgrad, _ = fp8_gemm(
weight_t_fp8,
weight_t_fp8._data,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -470,6 +493,7 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -554,9 +578,9 @@ class Linear(TransformerEngineBaseModule):
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
device: Union[torch.device, str] = "cuda",
ub_split_rs: bool = False,
ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
) -> None:
......@@ -570,6 +594,7 @@ class Linear(TransformerEngineBaseModule):
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.parameters_split = parameters_split
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
......@@ -609,18 +634,31 @@ class Linear(TransformerEngineBaseModule):
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.weight_tensor = torch.empty(
temp_weight = torch.empty(
self.out_features, self.in_features,
device=device, dtype=params_dtype)
# TODO(ksivaman): This functionality works with FP8 outside TE.
initialize_affine_weight_gpu(
self.weight_tensor,
temp_weight,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.weight_tensor = Float8Tensor.to_float8(
temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
else:
self.weight_tensor = temp_weight
if self.use_bias:
self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
else:
......@@ -657,9 +695,17 @@ class Linear(TransformerEngineBaseModule):
slice_end = slice_begin + slice_size
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
# TODO(ksivaman): Add indexing op to torch dispatcher for float8
if self.primary_weights_in_fp8:
assert len(parameters_split) == 1, ("Slicing operation is not "
"supported in Float8Tensor "
"class!")
self.register_parameter(wname, Parameter(self.weight_tensor))
else:
self.register_parameter(
wname, Parameter(self.weight_tensor[slice_begin:slice_end])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
......@@ -697,13 +743,13 @@ class Linear(TransformerEngineBaseModule):
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
) -> List[Float8Tensor]:
"""
Fetch the fp8 weight tensor placeholders if they exist (when
`is_first_microbatch` is not `None`) or return empty fp8 weight
tensors (if `is_first_microbatch is None`)
"""
if not self.fp8:
if not self.fp8 or self.primary_weights_in_fp8:
return [None, None]
if is_first_microbatch is None:
......@@ -747,6 +793,8 @@ class Linear(TransformerEngineBaseModule):
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
bias_tensor = (
self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
......@@ -790,6 +838,7 @@ class Linear(TransformerEngineBaseModule):
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
self.primary_weights_in_fp8,
self.ub_split_rs,
self.ub_split_ag,
self.ub_atomic_gemm_rs,
......