"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "e0ade06d6305cf84b41c1962cdd9dfdbfee16ac9"
Unverified commit b1820c44 authored by Tim Moon, committed by GitHub

[PyTorch] Experimental FP8 tensor class (#452)



* Experimental FP8 tensor
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add fp8 tensor to ci test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments and tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Default to FP8 usage
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Naming changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Minor fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix transpose caching
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug transpose caching

Handle case where transpose cache is updated externally.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename FP8GlobalStateManager.with_fp8_parameters
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove Float8Tensor from import API
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Avoid caching FP8 transposes if not required
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix import error in FP8 tensor tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix transpose caching and checkpointing bug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Improve caching and fix distopt case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/float8_tensor.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Remove recursive logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cache reset bug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Store FP8 attributes in dict

Easier for multiple tensors to share, e.g. detached tensors.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure scale_inv is 1D tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure scale_inv is 1D tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fixes and detach recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Set default fp8 data type
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
parent 67051eff
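For orientation before the diff below, here is a rough usage sketch of the feature this commit introduces, assembled from the `fp8_model_init` docstring and the new tests added in this change. It is an illustrative sketch by the editor, not code from the commit itself:

import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Modules created in this region hold only FP8 copies of their parameters
# (experimental; see the fp8_model_init docstring added by this commit).
with te.fp8_model_init(enabled=True):
    model = te.Linear(768, 768)

# Forward passes over FP8-only parameters must run inside an fp8_autocast region.
with te.fp8_autocast(enabled=True):
    out = model(torch.randn(16, 768, device="cuda"))

# The experimental Float8Tensor class can also be used directly to quantize and
# dequantize a tensor, as exercised by the new tests/pytorch/test_float8tensor.py.
x = torch.randn(16, 768, device="cuda")
x_fp8 = Float8Tensor.to_float8(x, fp8_dtype=tex.DType.kFloat8E4M3)
x_roundtrip = x_fp8.from_float8()

The key design point is that modules created under `fp8_model_init` keep only `Float8Tensor` parameters, so their forward pass must run inside an `fp8_autocast` region (the new modules assert this).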
@@ -35,6 +35,8 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.fp8_autocast
+.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
 .. autoapifunction:: transformer_engine.pytorch.checkpoint
 .. autoapifunction:: transformer_engine.pytorch.onnx_export
@@ -12,3 +12,4 @@ PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pyt
 pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
+pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections.abc import Iterable
from typing import Any, Dict, List, Tuple, Union

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine_extensions as tex

# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]

# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
    tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675),  # epsilon = 0.0625
    tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125),    # epsilon = 0.125
}

def _to_list(x: Union[Iterable, Any]) -> List:
    """Convert to list if iterable, otherwise put in singleton list"""
    if isinstance(x, Iterable):
        return list(x)
    else:
        return [x]

# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_constructor(
        self,
        dims: DimsType = 1,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale_inv: float = 0.375,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Call constructor and perform sanity checks"""
        dims = _to_list(dims)
        tensor = Float8Tensor(
            data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
            fp8_dtype=fp8_dtype,
            fp8_scale_inv=torch.full([1], scale_inv),
            dtype=dtype,
        )
        assert list(tensor.size()) == dims, "Incorrect dims"
        assert tensor.dtype == dtype, "Incorrect nominal dtype"
        assert tensor.is_cuda, "Incorrect device"

    def _test_quantize_dequantize(
        self,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
        dims: DimsType = 23,
    ) -> None:
        """Check numerical error when casting to FP8 and back"""

        # Initialize random data
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1

        # Cast to FP8 and back
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        x_fp8 = x_fp8.from_float8().cpu()

        # Check results
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_quantize_dequantize_dtypes(
        self,
        fp8_dtype: tex.DType,
        dtype: torch.dtype,
    ) -> None:
        self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype)

    @pytest.mark.parametrize("scale", [0.375, 1, 3.5])
    def test_quantize_dequantize_scales(self, scale: float) -> None:
        self._test_quantize_dequantize(scale=scale)

    @pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]])
    def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
        self._test_quantize_dequantize(dims=dims)

    def test_fp8_meta(
        self,
        dtype: torch.dtype = torch.float32,
        dims: DimsType = 23,
    ) -> None:
        """Construct Float8Tensor using FP8 metadata and perform basic checks"""

        # Get FP8 metadata from linear module
        fp8_dtype = tex.DType.kFloat8E4M3
        recipe = transformer_engine.common.recipe.DelayedScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            module = te.Linear(32, 32)
            _ = module(torch.zeros([8, 32], device="cuda"))
        fp8_meta = module.fp8_meta
        fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT
        fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1

        # Make Float8Tensor
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_meta=fp8_meta,
            fp8_meta_index=fp8_meta_index,
        )
        x_ref = x_fp8.from_float8()
        assert list(x_fp8.size()) == dims, "Incorrect dims"
        assert x_fp8.dtype == dtype, "Incorrect nominal dtype"
        assert x_fp8.is_cuda, "Incorrect device"
        assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype"

        # Change FP8 metadata scale
        fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2
        fp8_meta[fp8_meta_key].scale_inv.fill_(123)

        # Check results
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
        with pytest.raises(AssertionError):
            # Make sure we are not trivially passing the test
            torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])

        # Check if scaling factor is updated after in-place ops
        x_fp8 += 0
        fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4
        fp8_meta[fp8_meta_key].scale_inv.fill_(321)
        assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv"
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
        y = x_fp8.detach()
        y += 0
        assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv"
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])

    def test_basic_ops(
        self,
        dims: DimsType = 23,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test basic out-of-place ops"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        y_fp8 = Float8Tensor.to_float8(
            y_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        x_ref = x_fp8.from_float8()
        y_ref = y_fp8.from_float8()

        # Exact operations
        torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)

        # Operations with numerical error
        tols = _tols[fp8_dtype]
        torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
        torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
        torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)

        # Make sure we are not trivially passing tests
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)

    def test_inplace_ops(
        self,
        dims: DimsType = 23,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test in-place ops"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        y_fp8 = Float8Tensor.to_float8(
            y_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        x_ref = x_fp8.from_float8()
        y_ref = y_fp8.from_float8()

        # In-place operations
        tols = _tols[fp8_dtype]
        x_fp8 += y_ref
        x_ref += y_ref
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.from_float8()
        x_fp8 -= y_fp8
        x_ref -= y_fp8
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.from_float8()
        x_fp8 *= 2
        x_ref *= 2
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.from_float8()

        # Make sure we are not trivially passing tests
        x_ref += 123
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, x_ref, **tols)

    @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]])
    @pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)])
    def test_transpose(
        self,
        dims: DimsType,
        transpose_dims: Tuple[int, int],
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 1,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test transpose"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        x_ref = x_fp8.from_float8()

        # Perform transpose
        y_fp8 = x_fp8.transpose(*transpose_dims)
        y_ref = x_ref.transpose(*transpose_dims)

        # Check results
        tols = dict(rtol=0, atol=0)
        torch.testing.assert_close(y_fp8, y_ref, **tols)

        # Make sure we are not trivially passing the test
        if transpose_dims[0] != transpose_dims[1]:
            with pytest.raises(AssertionError):
                torch.testing.assert_close(
                    y_fp8,
                    x_ref,
                    **tols,
                )

        # Check transpose caching
        if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:
            x_fp8 += 0.5
            x_ref = x_fp8.from_float8()
            torch.testing.assert_close(
                x_fp8.transpose(*transpose_dims, update_cache=True),
                x_ref.transpose(*transpose_dims),
                **tols,
            )
            torch.testing.assert_close(
                x_fp8.transpose(*transpose_dims, update_cache=True),
                x_ref.transpose(*transpose_dims),
                **tols,
            )
            x_fp8 += 0.5
            x_ref = x_fp8.from_float8()
            torch.testing.assert_close(
                x_fp8.transpose(*transpose_dims, update_cache=True),
                x_ref.transpose(*transpose_dims),
                **tols,
            )
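A side note on the tolerances in `_tols` above (editorial, not part of the test file): the `# epsilon` comments correspond to the worst-case relative rounding error of each FP8 format, which is 2**-(mantissa_bits + 1). A small sketch of that arithmetic, assuming the standard E4M3 (3 mantissa bits) and E5M2 (2 mantissa bits) layouts:

# Hypothetical helper (not part of Transformer Engine): worst-case relative
# rounding error of a floating-point format with `mantissa_bits` mantissa bits.
def max_rel_rounding_error(mantissa_bits: int) -> float:
    return 2.0 ** -(mantissa_bits + 1)

assert max_rel_rounding_error(3) == 0.0625  # E4M3, matches "# epsilon = 0.0625"
assert max_rel_rounding_error(2) == 0.125   # E5M2, matches "# epsilon = 0.125"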
@@ -12,7 +12,7 @@ import torch
 import torch.nn as nn
 from torch.nn import Parameter
-from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -339,7 +339,7 @@ class TorchGPT(nn.Module):
         return x


-def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
+def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
     reset_rng_states()
     FP8GlobalStateManager.reset()
@@ -354,6 +354,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
         """Get cuda rng tracker."""
         return _DUMMY_CUDA_RNG_STATE_TRACKER

+    with fp8_model_init(enabled=fp8 and fp8_model_params):
         block = (
             TransformerLayer(
                 config.hidden_size,
@@ -369,6 +370,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
                 output_layernorm=False,
                 get_rng_state_tracker=get_dummy_cuda_rng_tracker,
                 params_dtype=dtype,
+                fuse_qkv_params=True,
             )
             .cuda()
         )
@@ -400,18 +402,19 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("fp8", all_boolean)
-def test_gpt_selective_activation_recompute(dtype, bs, model, fp8):
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)

     config = model_configs[model]

-    outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False)
-    outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True)
+    outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
+    outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
     assert_all_equal(outputs, outputs_recompute)


-def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
+def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
     reset_rng_states()
     FP8GlobalStateManager.reset()
@@ -426,6 +429,7 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
         """Get cuda rng tracker."""
         return _DUMMY_CUDA_RNG_STATE_TRACKER

+    with fp8_model_init(enabled=fp8 and fp8_model_params):
         block = (
             TransformerLayer(
                 config.hidden_size,
@@ -441,6 +445,7 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
                 output_layernorm=False,
                 get_rng_state_tracker=get_dummy_cuda_rng_tracker,
                 params_dtype=dtype,
+                fuse_qkv_params=True,
             )
             .cuda()
         )
@@ -483,14 +488,15 @@ def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("fp8", all_boolean)
-def test_gpt_full_activation_recompute(dtype, bs, model, fp8):
+@pytest.mark.parametrize("fp8_model_params", all_boolean)
+def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)

     config = model_configs[model]

-    outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False)
-    outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True)
+    outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
+    outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
     assert_all_equal(outputs, outputs_recompute)
@@ -871,6 +877,7 @@ def test_linear_accuracy(dtype, bs, model):
     else:
         assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
@@ -911,6 +918,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
     else:
         assert_allclose(te_outputs[0], torch_outputs[0], 2e-2)

+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
@@ -1110,3 +1118,72 @@ def test_gpt_cuda_graph(dtype, bs, model):
     assert_allclose(out, graphed_out, 1e-3)
     assert_allclose(params, graphed_params, 1e-3)
     assert_allclose(grads, graphed_grads, 1e-3)
+
+
+def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
+    reset_rng_states()
+    FP8GlobalStateManager.reset()
+
+    sigma = 0.023
+    init_method = init_method_normal(sigma)
+    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+
+    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
+    def get_dummy_cuda_rng_tracker():
+        """Get cuda rng tracker."""
+        return _DUMMY_CUDA_RNG_STATE_TRACKER
+
+    with fp8_model_init(enabled=fp8_model_params):
+        block = (
+            TransformerLayer(
+                config.hidden_size,
+                4 * config.hidden_size,
+                config.num_attention_heads,
+                layernorm_epsilon=config.eps,
+                init_method=init_method,
+                output_layer_init_method=output_layer_init_method,
+                hidden_dropout=0.1,
+                attention_dropout=0.1,
+                kv_channels=config.embed,
+                apply_residual_connection_post_layernorm=False,
+                output_layernorm=False,
+                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
+                params_dtype=dtype,
+                fuse_qkv_params=True,
+            )
+            .cuda()
+        )
+
+    te_inp_hidden_states = torch.randn(
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+    ).cuda()
+    te_inp_hidden_states.retain_grad()
+    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+
+    with fp8_autocast(enabled=True):
+        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
+    loss = te_out.sum()
+    loss.backward()
+    torch.cuda.synchronize()
+
+    outputs = [te_out, te_inp_hidden_states.grad]
+    for p in block.parameters():
+        if p.requires_grad:
+            outputs.append(p.grad)
+    return outputs
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+def test_gpt_fp8_parameters(dtype, bs, model):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+
+    config = model_configs[model]
+
+    outputs = _test_gpt_fp8_parameters(bs, dtype, config, False)
+    outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
+    assert_all_equal(outputs, outputs_fp8_params)
@@ -147,7 +147,7 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
     """Initialize the FP8 quantization scales in module"""
     NB_SCALES_PER_GEMM = 3  # One scale per: input, weights, and output GEMM tensors.
     nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
-    module.fp8_init(num_gemms)
+    module.init_fp8_metadata(num_gemms)
     module.fp8_meta["scaling_fwd"].scale = torch.ones(
         nb_total_scales, dtype=torch.float32, device="cuda") / scale
     module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
...
@@ -16,7 +16,7 @@ import pytest
 import torch
 import transformer_engine.pytorch as te
 import transformer_engine_extensions as tex
-from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8, cast_from_fp8
+from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
 from transformer_engine.pytorch.module.base import get_workspace
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -93,7 +93,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
     model_in = Test_TE_Export(precision, True)
     with te.fp8_autocast(enabled=True):
-        model_in.fp8_init()
+        model_in.init_fp8_metadata()
         # scaling fwd
         model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd
         model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd
...
@@ -13,6 +13,7 @@ from .attention import InferenceParams
 from .attention import MultiheadAttention
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
+from .fp8 import fp8_model_init
 from .export import onnx_export
 from .distributed import checkpoint
 from .distributed import CudaRNGStatesTracker
...
@@ -83,11 +83,13 @@ def initialize_affine_weight_gpu(
     weight: torch.Tensor,
     init_method: Callable,
     get_rng_state_tracker: Callable,
-    partition_dim: int,
+    partition_dim: int = 0,
     stride: int = 1,
+    set_tp_attributes: bool = True,
 ) -> None:
     """Initialize affine weight for model parallel on GPU."""
+    if set_tp_attributes:
         set_tensor_model_parallel_attributes(
             tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
         )
...
@@ -17,7 +17,7 @@ from .utils import get_device_compute_capability
 from .jit import jit_fuser

-__all__ = ["fp8_autocast"]
+__all__ = ["fp8_autocast", "fp8_model_init"]

 def check_fp8_support() -> Tuple[bool, str]:
@@ -59,6 +59,7 @@ class FP8GlobalStateManager:
     FP8_CALIBRATION = False
     FP8_RECIPE = None
     FP8_DISTRIBUTED_GROUP = None
+    FP8_PARAMETERS = False
     IS_FIRST_FP8_MODULE = False
     FP8_AUTOCAST_COUNTER = 0
     FP8_CURRENT_CONTEXT_ID = 0
@@ -277,6 +278,11 @@ class FP8GlobalStateManager:
         """Is FP8 calibration"""
         return cls.FP8_CALIBRATION

+    @classmethod
+    def with_fp8_parameters(cls) -> bool:
+        """Should the parameters be stored as FP8"""
+        return cls.FP8_PARAMETERS
+
     @classmethod
     def is_first_fp8_module(cls):
         """Returns `True` only the first time when called multiple
@@ -400,6 +406,11 @@ class FP8GlobalStateManager:
         fp8_group: Optional[dist_group_type] = None,
     ) -> None:
         """Set state and tracking variables for entry into FP8 region."""
+        if cls.FP8_AUTOCAST_DEPTH == 0:
+            if callable(cls.amax_forward_global_reduce_func):
+                cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func()  # pylint: disable=not-callable
+            cls.delete_key_from_amax_buffer(forward=True)
+
         cls.FP8_ENABLED = enabled
         cls.FP8_CALIBRATION = calibrating
         cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
@@ -419,11 +430,6 @@ class FP8GlobalStateManager:
         """Set state and tracking variables for exit from FP8 region."""
         cls.FP8_AUTOCAST_DEPTH -= 1

-        if cls.FP8_AUTOCAST_DEPTH == 0:
-            if callable(cls.amax_forward_global_reduce_func):
-                cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func()  # pylint: disable=not-callable
-            cls.delete_key_from_amax_buffer(forward=True)
-
     @classmethod
     def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
         """Copy the scaling factors and amaxes for recompute forward phase
@@ -477,9 +483,45 @@ class FP8GlobalStateManager:
         fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]

+
+@contextmanager
+def fp8_model_init(enabled: bool = True) -> None:
+    """
+    Context manager for FP8 initialization of parameters.
+
+    Example usage:
+
+    .. code-block:: python
+
+        with fp8_model_init(enabled=True):
+            model = transformer_engine.pytorch.Linear(768, 768)
+
+    Parameters
+    ----------
+    enabled: bool, default = `True`
+             when enabled, Transformer Engine modules created inside this `fp8_model_init`
+             region will hold only FP8 copies of its parameters, as opposed to the default
+             behavior where both higher precision and FP8 copies are present. Setting this
+             option to `True` may result in lower memory consumption and is especially
+             useful for scenarios like:
+
+             * full model training using optimizer with master weights, where the high
+               precision copies of weights are already present in the optimizer.
+             * inference, where only the FP8 copies of the parameters are used.
+             * LoRA-like fine-tuning, where the main parameters of the model do not change.
+
+             This functionality is *EXPERIMENTAL*.
+    """
+    try:
+        _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
+        FP8GlobalStateManager.FP8_PARAMETERS = enabled
+        yield
+    finally:
+        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters  # pylint: disable=used-before-assignment
+
+
 @contextmanager
 def fp8_autocast(
-    enabled: bool = False,
+    enabled: bool = True,
     calibrating: bool = False,
     fp8_recipe: Optional[DelayedScaling] = None,
     fp8_group: Optional[dist_group_type] = None,
@@ -508,7 +550,7 @@ def fp8_autocast(
     Parameters
     ----------
-    enabled: bool, default = `False`
+    enabled: bool, default = `True`
              whether or not to enable fp8
     calibrating: bool, default = `False`
              calibration mode allows collecting statistics such as amax and scale
@@ -523,7 +565,10 @@ def fp8_autocast(
     """
     try:
         fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
-        FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group)
+        FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
+                                                 calibrating=calibrating,
+                                                 fp8_recipe=fp8_recipe,
+                                                 fp8_group=fp8_group)
         yield
     finally:
         FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)  # pylint: disable=used-before-assignment
...
@@ -36,6 +36,7 @@ from ..cpp_extensions import (
     cast_to_fp8,
 )
 from ..constants import dist_group_type
+from ..float8_tensor import Float8Tensor

 _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
@@ -451,21 +452,29 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             setattr(
                 self,
                 weight_cast_attr,
-                torch.empty(
+                Float8Tensor(
+                    data=torch.empty(
                         shape,
                         device=torch.cuda.current_device(),
                         dtype=torch.uint8,
                     ),
+                    fp8_dtype=tex.DType.kFloat8E4M3,
+                    fp8_scale_inv=1,
+                )
             )
             setattr(
                 self,
                 weight_transpose_attr,
-                torch.empty(
+                Float8Tensor(
+                    data=torch.empty(
                         shape[1],
                         shape[0],
                         device=torch.cuda.current_device(),
                         dtype=torch.uint8,
                     ),
+                    fp8_dtype=tex.DType.kFloat8E4M3,
+                    fp8_scale_inv=1,
+                )
             )

     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
@@ -483,12 +492,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     # This routine is shared across FP8 and FP8_calibration paths so should not actually
     # assume FP8 execution.
-    def fp8_init(self, num_gemms: int = 1) -> None:
+    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
         """Initialize fp8 related metadata and tensors during fprop."""
+        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
         self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
         self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
         self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

+        if self.fp8_parameters and not self.fp8_initialized:
+            self.fp8_meta["num_gemms"] = num_gemms
+            self.init_fp8_meta_tensors()
+
         if self.fp8 or self.fp8_calibration:
             # FP8 init has already been run and recipe is the same, don't do anything.
             if (self.fp8_initialized
@@ -536,7 +550,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             assert self.tp_group_initialized, "TP group not initialized."

         self.set_activation_dtype(inp)
-        self.fp8_init(num_gemms=num_gemms)
+        self.init_fp8_metadata(num_gemms=num_gemms)

         # Create persistent tensors for fp8 weights and their transposes
         # only when fp8 weight caching is used.
@@ -765,7 +779,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     def get_fp8_weights_empty_tensors(
         self,
         is_first_microbatch: Union[bool, None],
-    ) -> List[torch.Tensor]:
+    ) -> List[Float8Tensor]:
         """
         Returns empty tensors to be later used to store fp8 version of weights
         and their transposes (for the bwd pass) for this batch (or microbatch).
@@ -781,23 +795,42 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         fp8_weight_tensors = []
         for shape in self.fp8_weight_shapes:
             fp8_weight_tensors.append(
-                torch.empty(
+                Float8Tensor(
+                    data=torch.empty(
                         shape,
                         device=torch.cuda.current_device(),
                         dtype=torch.uint8,
+                    ),
+                    fp8_dtype=tex.DType.kFloat8E4M3,
+                    fp8_scale_inv=1,
                 )
             )
             fp8_weight_tensors.append(
-                torch.empty(
+                Float8Tensor(
+                    data=torch.empty(
                         shape[1],
                         shape[0],
                         device=torch.cuda.current_device(),
                         dtype=torch.uint8,
+                    ),
+                    fp8_dtype=tex.DType.kFloat8E4M3,
+                    fp8_scale_inv=1,
                 )
             )
         return fp8_weight_tensors

+    def state_dict(self, *args, **kwargs) -> Dict:
+        """Get dictionary containing module state"""
+        state = super().state_dict(*args, **kwargs)
+
+        # Convert Float8Tensors to plain tensors
+        # Note: Float8Tensors don't serialize well, especially if they
+        # contain references to FP8 metadata.
+        for key, val in state.items():
+            if isinstance(val, Float8Tensor):
+                state[key] = val.from_float8()
+
+        return state
+
     @abstractmethod
     def forward(self):
...
@@ -23,7 +23,7 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ..fp8 import get_fp8_te_dtype
+from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
 from ..utils import (
     divide,
     get_default_init_method,
@@ -43,6 +43,7 @@ from ..jit import no_torch_dynamo
 from ._common import _apply_normalization
+from ..float8_tensor import Float8Tensor

 __all__ = ["LayerNormLinear"]
@@ -79,10 +80,11 @@ class _LayerNormLinear(torch.autograd.Function):
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
+        normalization: str,
+        primary_weights_in_fp8: bool,
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
         ub_split_ag: bool,
-        normalization: str,
         ub_atomic_gemm_ag: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
@@ -159,28 +161,43 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

-            if update_fp8_weights:
+            if primary_weights_in_fp8:
+                # Weight is already in FP8
+                weight.reset_fp8_meta_scale_inv()
+                weight_fp8 = weight
+                weight_t_fp8 = None
+                if is_grad_enabled:
+                    weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
+            elif update_fp8_weights:
+                # Need to cast weights to FP8
+                weight_fp8 = Float8Tensor(
+                    data=weight_fp8._data,
+                    fp8_meta=fp8_meta,
+                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
+                )
                 if is_grad_enabled:
                     tex.fp8_cast_transpose_fused(
                         weight,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM1_WEIGHT,
                         fp8_dtype_forward,
-                        cast_out=weight_fp8,
-                        transpose_out=weight_t_fp8,
+                        cast_out=weight_fp8._data,
+                        transpose_out=weight_t_fp8._data,
                     )
                 else:
-                    weight_t_fp8 = None
-                    weight_fp8 = tex.cast_to_fp8(
+                    weight_fp8._data = tex.cast_to_fp8(
                         weight,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM1_WEIGHT,
-                        fp8_dtype_forward)
+                        fp8_dtype_forward,
+                    )
+                    weight_t_fp8 = None

             ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
             ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
             out, _ = tex.fp8_gemm(
-                weight_fp8,
+                weight_fp8._data,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
                 fp8_dtype_forward,
@@ -356,7 +373,7 @@ class _LayerNormLinear(torch.autograd.Function):
             # DGRAD: Evaluated unconditionally to feed into Linear backward
             _ = tex.fp8_gemm(
-                weight_t_fp8,
+                weight_t_fp8._data,
                 fwd_scale_inverses,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
                 fp8_dtype_forward,
@@ -544,6 +561,7 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -646,10 +664,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
         return_layernorm_output: bool = False,
         parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
         zero_centered_gamma: bool = False,
+        device: Union[torch.device, str] = "cuda",
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
         ub_split_ag: bool = False,
-        device: Union[torch.device, str] = "cuda",
         ub_atomic_gemm_ag: bool = False,
     ) -> None:
         super().__init__()
@@ -666,6 +684,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.return_layernorm_output = return_layernorm_output
         self.parameters_split = parameters_split
         self.zero_centered_gamma = zero_centered_gamma
+        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
         self.ub_bulk_wgrad = ub_bulk_wgrad
         self.ub_bulk_dgrad = ub_bulk_dgrad
         self.ub_split_ag = ub_split_ag
@@ -719,18 +738,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.layer_norm_bias = None
             self.reset_layer_norm_parameters()

-        self.weight_tensor = torch.empty(
+        temp_weight = torch.empty(
             self.out_features, self.in_features,
             device=device, dtype=params_dtype)

         initialize_affine_weight_gpu(
-            self.weight_tensor,
+            temp_weight,
             init_method,
             get_rng_state_tracker,
             partition_dim=1 if self.parallel_mode == "row" else 0,
             stride=1,
         )

+        if self.primary_weights_in_fp8:
+            self.init_fp8_metadata()
+            self.fp8_meta["update_amax_and_scale_fwd"] = True
+            self.weight_tensor = Float8Tensor.to_float8(
+                temp_weight,
+                fp8_meta=self.fp8_meta,
+                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
+            )
+        else:
+            self.weight_tensor = temp_weight
+
         if self.use_bias:
             self.bias_tensor = torch.empty(
                 self.out_features,
@@ -769,7 +800,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 bname = pname + "bias"
                 slice_end = slice_begin + slice_size

+                # NOTE(future): Figure out a way to support slicing when weights
+                # are of `Float8Tensor` class
+                if self.primary_weights_in_fp8:
+                    assert len(parameters_split) == 1, ("Slicing operation is not "
+                                                        "supported in Float8Tensor "
+                                                        "class!")
+                    self.register_parameter(wname, Parameter(self.weight_tensor))
+                else:
                     self.register_parameter(
                         wname, Parameter(self.weight_tensor[slice_begin:slice_end])
                     )
@@ -833,7 +871,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
            `is_first_microbatch` is not `None`) or return empty fp8 weight
            tensors (if `is_first_microbatch is None`)
         """
-        if not self.fp8:
+        if not self.fp8 or self.primary_weights_in_fp8:
             return [None, None]

         if is_first_microbatch is None:
@@ -877,6 +915,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
         """

         with self.prepare_forward(inp, is_first_microbatch) as inp:
+            assert self.fp8 or not self.primary_weights_in_fp8, \
+                   "Need to run inside fp8_autocast region when weights are stored in FP8."
             bias_tensor = (
                 self.bias if self.parameters_split is None
                 else self.bias_tensor if not torch.is_grad_enabled()
@@ -927,10 +967,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.fwd_ln_sm_margin,
                 self.bwd_ln_sm_margin,
                 self.zero_centered_gamma,
+                self.normalization,
+                self.primary_weights_in_fp8,
                 self.ub_bulk_wgrad,
                 self.ub_bulk_dgrad,
                 self.ub_split_ag,
-                self.normalization,
                 self.ub_atomic_gemm_ag,
             )
             out = fwd_fn(*args)
...
...@@ -20,7 +20,7 @@ from .base import ( ...@@ -20,7 +20,7 @@ from .base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from ..fp8 import get_fp8_te_dtype from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..jit import ( from ..jit import (
bias_gelu_fused, bias_gelu_fused,
bgrad_dgelu_fused, bgrad_dgelu_fused,
...@@ -47,6 +47,7 @@ from .. import cpp_extensions as tex ...@@ -47,6 +47,7 @@ from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization from ._common import _apply_normalization
...@@ -105,14 +106,15 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -105,14 +106,15 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
activation: str,
normalization: str,
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool, ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool, ub_bulk_dgrad: bool,
ub_split_rs: bool, ub_split_rs: bool,
ub_atomic_gemm_rs: bool, ub_atomic_gemm_rs: bool,
ub_split_ag: bool, ub_split_ag: bool,
ub_atomic_gemm_ag: bool, ub_atomic_gemm_ag: bool,
activation: str,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -196,45 +198,68 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -196,45 +198,68 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias
fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias
if update_fp8_weights: if primary_weights_in_fp8:
# Weights are already in FP8
fc1_weight.reset_fp8_meta_scale_inv()
fc2_weight.reset_fp8_meta_scale_inv()
fc1_weight_fp8 = fc1_weight
fc2_weight_fp8 = fc2_weight
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
if is_grad_enabled: if is_grad_enabled:
fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch)
fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch)
elif update_fp8_weights:
# Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor(
data=fc1_weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
)
fc2_weight_fp8 = Float8Tensor(
data=fc2_weight_fp8._data,
fp8_meta=fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
)
if is_grad_enabled:
# Fused cast-transpose kernels
tex.fp8_cast_transpose_fused( tex.fp8_cast_transpose_fused(
fc1_weight, fc1_weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
cast_out=fc1_weight_fp8, cast_out=fc1_weight_fp8._data,
transpose_out=fc1_weight_t_fp8, transpose_out=fc1_weight_t_fp8._data,
) )
tex.fp8_cast_transpose_fused( tex.fp8_cast_transpose_fused(
fc2_weight, fc2_weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
cast_out=fc2_weight_fp8, cast_out=fc2_weight_fp8._data,
transpose_out=fc2_weight_t_fp8, transpose_out=fc2_weight_t_fp8._data,
) )
else: else:
fc1_weight_t_fp8 = None fc1_weight_fp8._data = tex.cast_to_fp8(
fc1_weight_fp8 = tex.cast_to_fp8(
fc1_weight, fc1_weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
) )
fc2_weight_t_fp8 = None fc1_weight_t_fp8 = None
fc2_weight_fp8 = tex.cast_to_fp8( fc2_weight_fp8._data = tex.cast_to_fp8(
fc2_weight, fc2_weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
) )
fc2_weight_t_fp8 = None
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
fc1_out, _ = tex.fp8_gemm( fc1_out, _ = tex.fp8_gemm(
fc1_weight_fp8, fc1_weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -283,7 +308,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -283,7 +308,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = tex.fp8_gemm( _ = tex.fp8_gemm(
fc2_weight_fp8, fc2_weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -530,7 +555,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -530,7 +555,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
# FC2 DGRAD; Unconditional # FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm( fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8, fc2_weight_t_fp8._data,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -645,7 +670,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -645,7 +670,7 @@ class _LayerNormMLP(torch.autograd.Function):
) )
# FC1 DGRAD: Unconditional # FC1 DGRAD: Unconditional
_ = tex.fp8_gemm( _ = tex.fp8_gemm(
fc1_weight_t_fp8, fc1_weight_t_fp8._data,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -908,6 +933,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -908,6 +933,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -1020,12 +1046,12 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1020,12 +1046,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
micro_batch_size: Optional[int] = None, micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False, ub_split_rs: bool = False,
ub_atomic_gemm_rs: bool = False, ub_atomic_gemm_rs: bool = False,
ub_split_ag: bool = False, ub_split_ag: bool = False,
device: Union[torch.device, str] = "cuda",
ub_atomic_gemm_ag: bool = False, ub_atomic_gemm_ag: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1043,6 +1069,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1043,6 +1069,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.activation == 'gelu') self.activation == 'gelu')
self.set_parallel_mode = set_parallel_mode self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_rs = ub_split_rs self.ub_split_rs = ub_split_rs
...@@ -1102,19 +1129,30 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1102,19 +1129,30 @@ class LayerNormMLP(TransformerEngineBaseModule):
else: else:
fc1_output_features = self.size_per_partition fc1_output_features = self.size_per_partition
# FC1 init # FC1 init
self.fc1_weight = Parameter( fc1_temp_weight = torch.empty(
torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype) fc1_output_features, hidden_size, device=device, dtype=params_dtype)
)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
initialize_affine_weight_gpu( initialize_affine_weight_gpu(
self.fc1_weight, fc1_temp_weight,
init_method, init_method,
get_rng_state_tracker, get_rng_state_tracker,
partition_dim=0, set_tp_attributes=False,
stride=1, )
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=2)
self.fp8_meta["update_amax_and_scale_fwd"] = True
fc1_temp_weight = Float8Tensor.to_float8(
fc1_temp_weight,
fp8_meta=self.fp8_meta,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
) )
self.fc1_weight = Parameter(fc1_temp_weight)
set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
if self.use_bias: if self.use_bias:
self.fc1_bias = Parameter( self.fc1_bias = Parameter(
torch.empty(fc1_output_features, device=device, dtype=params_dtype) torch.empty(fc1_output_features, device=device, dtype=params_dtype)
@@ -1127,19 +1165,27 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 self.fc1_bias.zero_()
         # FC2 init
-        self.fc2_weight = Parameter(
-            torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
-        )
-        self.fp8_weight_shapes.append(self.fc2_weight.shape)
+        fc2_temp_weight = torch.empty(
+            hidden_size, self.size_per_partition, device=device, dtype=params_dtype)
         initialize_affine_weight_gpu(
-            self.fc2_weight,
+            fc2_temp_weight,
             output_layer_init_method,
             get_rng_state_tracker,
-            partition_dim=1,
-            stride=1,
+            set_tp_attributes=False,
         )
+        if self.primary_weights_in_fp8:
+            fc2_temp_weight = Float8Tensor.to_float8(
+                fc2_temp_weight,
+                fp8_meta=self.fp8_meta,
+                fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT,
+            )
+        self.fc2_weight = Parameter(fc2_temp_weight)
+        set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1)
+        self.fp8_weight_shapes.append(self.fc2_weight.shape)
         if self.use_bias:
             self.fc2_bias = Parameter(
                 torch.empty(hidden_size, device=device, dtype=params_dtype)
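For readers skimming the new init path in both FC1 and FC2: the weight is first materialized and initialized in `params_dtype`, then (when `primary_weights_in_fp8` is set) quantized once into a `Float8Tensor` via `to_float8` before being wrapped in a `Parameter`. A rough standalone sketch of that quantize-on-init idea, in plain PyTorch with a hand-rolled scale (the helper name and constants below are illustrative, not part of this PR):

```python
import torch

# Illustrative only: emulate "primary weights in FP8" with plain PyTorch
# (needs torch >= 2.1 for float8_e4m3fn). TE's Float8Tensor additionally
# tracks amax history and scale_inv inside fp8_meta.
def quantize_weight_to_fp8(weight: torch.Tensor):
    amax = weight.abs().max().float()
    fp8_max = 448.0                               # largest finite E4M3 value
    scale = fp8_max / torch.clamp(amax, min=1e-12)
    data = (weight.float() * scale).to(torch.float8_e4m3fn)
    return data, 1.0 / scale                      # keep scale_inv for dequantization

w_hp = torch.randn(256, 256, dtype=torch.bfloat16) * 0.02   # high-precision init
w_fp8, scale_inv = quantize_weight_to_fp8(w_hp)
w_approx = w_fp8.float() * scale_inv              # dequantized view of the stored weight
```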
@@ -1192,7 +1238,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         `is_first_microbatch` is not `None`) or return empty fp8 weight
         tensors (if `is_first_microbatch is None`)
         """
-        if not self.fp8:
+        if not self.fp8 or self.primary_weights_in_fp8:
             return [None, None, None, None]
         if is_first_microbatch is None:
@@ -1235,6 +1281,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
         """
         with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
+            assert self.fp8 or not self.primary_weights_in_fp8, \
+                "Need to run inside fp8_autocast region when weights are stored in FP8."
             # Fetch the fp8 weights placeholders (for linear/gemm)
             weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \
                 self.get_fp8_weights_scratchpad(
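The assertion added above couples FP8-stored weights to the autocast context: a module whose primary weights live in FP8 must run its forward pass inside `fp8_autocast`. A hedged usage sketch (sizes are arbitrary, and the mechanism that opts the module into FP8 parameter storage is configured elsewhere and not shown in this hunk):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Usage sketch: LayerNormMLP forward/backward inside an fp8_autocast region.
mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096)
inp = torch.randn(128, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = mlp(inp)
out.sum().backward()
```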
@@ -1279,14 +1327,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 self.fwd_ln_sm_margin,
                 self.bwd_ln_sm_margin,
                 self.zero_centered_gamma,
+                self.activation,
+                self.normalization,
+                self.primary_weights_in_fp8,
                 self.ub_bulk_wgrad,
                 self.ub_bulk_dgrad,
                 self.ub_split_rs,
                 self.ub_atomic_gemm_rs,
                 self.ub_split_ag,
                 self.ub_atomic_gemm_ag,
-                self.activation,
-                self.normalization,
             )
             out = fwd_fn(*args)
...
@@ -20,7 +20,7 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ..fp8 import get_fp8_te_dtype
+from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
 from ..utils import (
     divide,
     get_default_init_method,
@@ -45,6 +45,8 @@ from ..cpp_extensions import (
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
+from ..float8_tensor import Float8Tensor
 
 __all__ = ["Linear"]
@@ -57,9 +59,9 @@ class _Linear(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        weight: torch.Tensor,
-        weight_fp8: Union[torch.Tensor, None],
-        weight_t_fp8: Union[torch.Tensor, None],
+        weight: Union[Float8Tensor, torch.Tensor],
+        weight_fp8: Union[Float8Tensor, None],
+        weight_t_fp8: Union[Float8Tensor, None],
         inp: torch.Tensor,
         bias: torch.Tensor,
         use_bias: bool,
@@ -75,6 +77,7 @@ class _Linear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
+        primary_weights_in_fp8: bool,
         ub_split_rs: bool,
         ub_split_ag: bool,
         ub_atomic_gemm_rs: bool,
@@ -141,24 +144,38 @@ class _Linear(torch.autograd.Function):
             )
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
-            if update_fp8_weights:
+            if primary_weights_in_fp8:
+                # Weight is already in FP8
+                weight.reset_fp8_meta_scale_inv()
+                weight_fp8 = weight
+                weight_t_fp8 = None
+                if is_grad_enabled:
+                    weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch)
+            elif update_fp8_weights:
+                # Need to cast weights to FP8
+                weight_fp8 = Float8Tensor(
+                    data=weight_fp8._data,
+                    fp8_meta=fp8_meta,
+                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
+                )
                 if is_grad_enabled:
                     fp8_cast_transpose_fused(
                         weight,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM1_WEIGHT,
                         fp8_dtype_forward,
-                        cast_out=weight_fp8,
-                        transpose_out=weight_t_fp8,
+                        cast_out=weight_fp8._data,
+                        transpose_out=weight_t_fp8._data,
                     )
                 else:
-                    weight_t_fp8 = None
-                    weight_fp8 = cast_to_fp8(
+                    weight_fp8._data = cast_to_fp8(
                         weight,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM1_WEIGHT,
                         fp8_dtype_forward,
                     )
+                    weight_t_fp8 = None
             proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                 None, None, None, activation_dtype)
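For context on the `primary_weights_in_fp8` branch: the weight already lives in FP8, so no cast is needed and only its transpose is (re)built, with `update_cache=is_first_microbatch` letting later microbatches reuse the cached copy. A toy sketch of that caching pattern (not TE's actual Float8Tensor, just the idea it relies on):

```python
import torch

# Minimal sketch of transpose caching across microbatches. The _data field
# stands in for Float8Tensor's raw FP8 payload.
class CachedTranspose:
    def __init__(self, data: torch.Tensor):
        self._data = data
        self._transpose_cache = None

    def transpose(self, update_cache: bool = True) -> torch.Tensor:
        if update_cache or self._transpose_cache is None:
            self._transpose_cache = self._data.t().contiguous()  # rebuild cache
        return self._transpose_cache

w = CachedTranspose(torch.randint(0, 255, (64, 32), dtype=torch.uint8))
t_first = w.transpose(update_cache=True)    # first microbatch: build the cache
t_later = w.transpose(update_cache=False)   # later microbatches: reuse it
assert t_later is t_first
```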
@@ -184,7 +201,7 @@ class _Linear(torch.autograd.Function):
                 ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
                 ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
                 _ = fp8_gemm(
-                    weight_fp8,
+                    weight_fp8._data,
                     fp8_meta["scaling_fwd"].scale_inv,
                     tex.FP8FwdTensors.GEMM1_WEIGHT,
                     fp8_dtype_forward,
@@ -245,6 +262,9 @@ class _Linear(torch.autograd.Function):
         if is_grad_enabled:
             fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
+            if fp8:
+                assert hasattr(weight_t_fp8, "_data"), \
+                    "_data attr doesn't exist (before save for bwd)"
             ctx.save_for_backward(
                 inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
                 inputmat_t if weight.requires_grad and fp8_wgrad else None,
@@ -294,6 +314,9 @@ class _Linear(torch.autograd.Function):
             weight_t_fp8,
             fwd_scale_inverses,
         ) = ctx.saved_tensors
+        if weight_t_fp8 is not None:
+            assert hasattr(weight_t_fp8, "_data"), \
+                "_data attr doesn't exist (after restore in bwd)"
         if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
@@ -349,7 +372,7 @@ class _Linear(torch.autograd.Function):
         if ctx.requires_dgrad:
             if ctx.fp8:
                 dgrad, _ = fp8_gemm(
-                    weight_t_fp8,
+                    weight_t_fp8._data,
                     fwd_scale_inverses,
                     tex.FP8FwdTensors.GEMM1_WEIGHT,
                     fp8_dtype_forward,
@@ -470,6 +493,7 @@ class _Linear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
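The extra `None` appended to the return tuple mirrors the new `primary_weights_in_fp8` argument added to `forward`: `torch.autograd.Function.backward` must return exactly one value per `forward` input, with `None` for non-differentiable ones. A minimal, generic illustration of that rule (unrelated to TE specifics):

```python
import torch

# backward() returns one gradient per forward() input; non-differentiable
# inputs (here: factor and some_flag) get None.
class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, some_flag):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.factor, None, None

x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0, True).sum().backward()
print(x.grad)   # tensor of 2.0s
```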
@@ -554,9 +578,9 @@ class Linear(TransformerEngineBaseModule):
         params_dtype: Optional[torch.dtype] = None,
         parallel_mode: Optional[str] = None,
         parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
+        device: Union[torch.device, str] = "cuda",
         ub_split_rs: bool = False,
         ub_split_ag: bool = False,
-        device: Union[torch.device, str] = "cuda",
         ub_atomic_gemm_rs: bool = False,
         ub_atomic_gemm_ag: bool = False,
     ) -> None:
@@ -570,6 +594,7 @@ class Linear(TransformerEngineBaseModule):
         self.return_bias = return_bias
         self.apply_bias = bias and not return_bias
         self.parameters_split = parameters_split
+        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
         self.ub_split_rs = ub_split_rs
         self.ub_split_ag = ub_split_ag
         self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
@@ -609,18 +634,31 @@ class Linear(TransformerEngineBaseModule):
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
-        self.weight_tensor = torch.empty(
+        temp_weight = torch.empty(
             self.out_features, self.in_features,
             device=device, dtype=params_dtype)
+        # TODO(ksivaman): This functionality works with FP8 outside TE.
         initialize_affine_weight_gpu(
-            self.weight_tensor,
+            temp_weight,
             init_method,
             get_rng_state_tracker,
             partition_dim=1 if self.parallel_mode == "row" else 0,
             stride=1,
         )
+        if self.primary_weights_in_fp8:
+            self.init_fp8_metadata()
+            self.fp8_meta["update_amax_and_scale_fwd"] = True
+            self.weight_tensor = Float8Tensor.to_float8(
+                temp_weight,
+                fp8_meta=self.fp8_meta,
+                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
+            )
+        else:
+            self.weight_tensor = temp_weight
         if self.use_bias:
             self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype)
         else:
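As a rough way to see what this branch produces (hedged: whether `FP8GlobalStateManager.with_fp8_parameters()` reports True at construction time depends on global state configured outside this file):

```python
import transformer_engine.pytorch as te

# Illustrative inspection of the primary weight's storage after construction.
linear = te.Linear(1024, 1024, bias=True)
w = linear.weight
print(type(w).__name__, tuple(w.shape))
# When primary weights are stored in FP8, the raw byte payload is w._data and
# the dequantization scale lives in the module's fp8_meta; otherwise w is an
# ordinary high-precision Parameter.
```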
@@ -657,6 +695,14 @@ class Linear(TransformerEngineBaseModule):
                 slice_end = slice_begin + slice_size
+                # TODO(ksivaman): Add indexing op to torch dispatcher for float8
+                if self.primary_weights_in_fp8:
+                    assert len(parameters_split) == 1, ("Slicing operation is not "
+                                                        "supported in Float8Tensor "
+                                                        "class!")
+                    self.register_parameter(wname, Parameter(self.weight_tensor))
+                else:
-                self.register_parameter(
-                    wname, Parameter(self.weight_tensor[slice_begin:slice_end])
-                )
+                    self.register_parameter(
+                        wname, Parameter(self.weight_tensor[slice_begin:slice_end])
+                    )
@@ -697,13 +743,13 @@ class Linear(TransformerEngineBaseModule):
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
-    ) -> List[torch.Tensor]:
+    ) -> List[Float8Tensor]:
         """
         Fetch the fp8 weight tensor placeholders if they exist (when
         `is_first_microbatch` is not `None`) or return empty fp8 weight
         tensors (if `is_first_microbatch is None`)
         """
-        if not self.fp8:
+        if not self.fp8 or self.primary_weights_in_fp8:
             return [None, None]
         if is_first_microbatch is None:
@@ -747,6 +793,8 @@ class Linear(TransformerEngineBaseModule):
         """
         with self.prepare_forward(inp, is_first_microbatch) as inp:
+            assert self.fp8 or not self.primary_weights_in_fp8, \
+                "Need to run inside fp8_autocast region when weights are stored in FP8."
             bias_tensor = (
                 self.bias if self.parameters_split is None
                 else self.bias_tensor if not torch.is_grad_enabled()
@@ -790,6 +838,7 @@ class Linear(TransformerEngineBaseModule):
                 self.activation_dtype,
                 self.parallel_mode,
                 torch.is_grad_enabled(),
+                self.primary_weights_in_fp8,
                 self.ub_split_rs,
                 self.ub_split_ag,
                 self.ub_atomic_gemm_rs,
...