Unverified commit 85a91997, authored by Kirthi Shankar Sivamani, committed by GitHub

Generalize quantization APIs for FP8/FP4/.. recipes (#2256)



* Initial API change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change all imports and api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix recipe tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix more tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix docs, tests, and make Jax change as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change internal uses of fp8_autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address nits
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rename file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CG function, and small test fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change instances of make_graphed_callables internally
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix distributed tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test and add more docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup test imports and minimize internal file imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make is_bf16_available public
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better docs and better api
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* fix nvfp4 test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ca6fedcf
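The diffs below all follow the same pattern: te.fp8_autocast(..., fp8_recipe=...) becomes the recipe-agnostic te.autocast(..., recipe=...). A minimal sketch of the new PyTorch spelling, with the pre-PR call shown in a comment (the layer size and data are illustrative, and a CUDA device with FP8-capable hardware is assumed):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    model = te.Linear(768, 768).cuda()
    inp = torch.randn(1024, 768, device="cuda")
    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

    # Pre-PR spelling:
    #   with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # New, recipe-agnostic spelling:
    with te.autocast(enabled=True, recipe=fp8_recipe):
        out = model(inp)

    # The backward pass stays outside the autocast region.
    loss = out.sum()
    loss.backward()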
......@@ -86,7 +86,7 @@ PyTorch
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with te.autocast(enabled=True, recipe=fp8_recipe):
out = model(inp)
loss = out.sum()
......@@ -121,7 +121,7 @@ Flax
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with te.autocast(enabled=True, recipe=fp8_recipe):
model = te_flax.DenseGeneral(features=HIDDEN)
def loss_fn(params, other_vars, inp):
......
......@@ -6,11 +6,10 @@ import argparse
import torch
import torch.utils.benchmark as benchmark
import pandas as pd
import pathlib
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager
from contextlib import nullcontext
"""
......@@ -51,9 +50,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
assert mode in ["fwd_only", "fwd_bwd"]
fp8_context = (
fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
)
fp8_context = autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
# print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")
if mode == "fwd_only":
......
......@@ -30,6 +30,7 @@ Modules
.. autoapifunction:: transformer_engine.jax.fp8_autocast
.. autoapifunction:: transformer_engine.jax.autocast
.. autoapifunction:: transformer_engine.jax.update_collections
......
......@@ -41,8 +41,28 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
.. autoapifunction:: transformer_engine.pytorch.autocast
.. autoapifunction:: transformer_engine.pytorch.quantized_model_init
.. autoapifunction:: transformer_engine.pytorch.checkpoint
.. autoapifunction:: transformer_engine.pytorch.is_fp8_available
.. autoapifunction:: transformer_engine.pytorch.is_mxfp8_available
.. autoapifunction:: transformer_engine.pytorch.is_fp8_block_scaling_available
.. autoapifunction:: transformer_engine.pytorch.is_nvfp4_available
.. autoapifunction:: transformer_engine.pytorch.is_bf16_available
.. autoapifunction:: transformer_engine.pytorch.get_cudnn_version
.. autoapifunction:: transformer_engine.pytorch.get_device_compute_capability
.. autoapifunction:: transformer_engine.pytorch.get_default_recipe
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
......
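The pytorch.rst hunk above makes several capability queries public. A tentative usage sketch: the exact return values and argument lists of these helpers are not visible in this diff, so the code only prints them rather than branching on them, and the no-argument calls are an assumption.

    import transformer_engine.pytorch as te

    # Capability and environment queries newly listed in the public API docs.
    print("cuDNN version:       ", te.get_cudnn_version())
    print("Compute capability:  ", te.get_device_compute_capability())
    print("FP8 available:       ", te.is_fp8_available())
    print("MXFP8 available:     ", te.is_mxfp8_available())
    print("FP8 block scaling:   ", te.is_fp8_block_scaling_available())
    print("NVFP4 available:     ", te.is_nvfp4_available())
    print("BF16 available:      ", te.is_bf16_available())
    # Assumed to return a recipe object usable with te.autocast(recipe=...).
    print("Default recipe:      ", te.get_default_recipe())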
......@@ -69,7 +69,7 @@ Let's look at a simple example of training a Transformer layer using Transformer
for epoch in range(5):
transformer_layer.train()
optimizer.zero_grad()
with te.fp8_autocast(enabled=True):
with te.autocast(enabled=True):
output = transformer_layer(dummy_input)
loss = criterion(output, dummy_target)
loss.backward()
......
......@@ -71,7 +71,7 @@
" amax_compute_algo=\"max\",\n",
")\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" y = basic_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"\n",
......@@ -81,7 +81,7 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
},
......@@ -135,7 +135,7 @@
"\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
"\n",
"One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager."
"One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager."
]
},
{
......@@ -169,7 +169,7 @@
")\n",
"\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=world_group):\n",
" y = parallel_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"\n",
......@@ -179,10 +179,10 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = {\n",
" autocast_kwargs = {\n",
" \"enabled\": True,\n",
" \"fp8_recipe\": fp8_recipe,\n",
" \"fp8_group\": world_group,\n",
" \"recipe\": fp8_recipe,\n",
" \"amax_reduction_group\": world_group,\n",
" },\n",
")"
]
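The hunk above renames fp8_group to amax_reduction_group. A sketch of the multi-GPU flow under the new spelling, assuming one process per GPU launched with torchrun and an NCCL backend; the layer sizes and tensor shapes are illustrative:

    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    # Assumes torchrun has set up one process per GPU.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    world_group = dist.new_group(backend="nccl")

    layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16).cuda()
    x = torch.randn(128, 4, 1024, device="cuda")
    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

    # Amax/scale synchronization now happens over `amax_reduction_group`
    # (previously `fp8_group`).
    with te.autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=world_group):
        y = layer(x, attention_mask=None)
    y.sum().backward()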
......@@ -234,7 +234,7 @@
" param.main_grad = torch.zeros_like(param, dtype=torch.float32)\n",
"\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" y = wgrad_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"for param in wgrad_transformer.parameters():\n",
......@@ -248,7 +248,7 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
},
......@@ -268,7 +268,7 @@
"\n",
"</div>\n",
"\n",
"Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n",
"Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
......@@ -303,12 +303,12 @@
"weight_caching_transformer.to(dtype=dtype).cuda()\n",
"\n",
"# Cast weights in first gradient accumulation step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=True)\n",
"y.backward(dy)\n",
"\n",
"# Reuse FP8 weights in subsequent gradient accumulation steps\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=False)\n",
"y.backward(dy)\n",
"\n",
......@@ -318,7 +318,7 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None, \"is_first_microbatch\": False },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
}
......
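The weight-caching hunks above change only the autocast spelling; the is_first_microbatch mechanics are unchanged. A condensed, self-contained sketch of caching FP8 weights across gradient-accumulation steps (layer sizes illustrative):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)
    layer = layer.to(dtype=torch.bfloat16).cuda()
    x = torch.randn(128, 4, 1024, dtype=torch.bfloat16, device="cuda")
    dy = torch.randn_like(x)  # output has the same (seq, batch, hidden) shape as the input
    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)

    # First micro-batch: cast weights to FP8 and cache the result.
    with te.autocast(enabled=True, recipe=fp8_recipe):
        y = layer(x, attention_mask=None, is_first_microbatch=True)
    y.backward(dy)

    # Remaining micro-batches: reuse the cached FP8 weights.
    with te.autocast(enabled=True, recipe=fp8_recipe):
        y = layer(x, attention_mask=None, is_first_microbatch=False)
    y.backward(dy)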
......@@ -132,7 +132,7 @@
" - 2D Scaling. The non-square size of the quantization blocks, while increasing granularity, has a property that the quantized tensor and its transpose no longer hold the same values. This is important since the transposed tensors are used when calculating gradients of the linear layers. While most tensors are not sensitive to this issue during training, it does affect the training accuracy when applied to the weight tensors. Therefore, the weights of the linear layers are quantized using a 2D scheme, where a single scaling factor is shared by a 2D block of 16x16 elements.\n",
" - Random Hadamard Transforms. While microscaling reduces the dynamic range needed to represent tensor values, outliers can still have a\n",
"disproportionate impact on FP4 formats, degrading model accuracy. Random Hadamard transforms address this by reshaping the tensor distribution to be more Gaussian-like, which smooths outliers and makes tensors easier to represent accurately in NVFP4. In Transformer Engine, we use a 16x16 Hadamard matrix for activations and gradients when performing weight gradient computation.\n",
" - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization and so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have the full information about the structure of the network being trained. This can be easily achieved though by modifying the model training code to run the last few layers under a different `fp8_autocast` (or nesting 2 autocasts in order to override the recipe for a part of the network).\n",
" - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization and so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have the full information about the structure of the network being trained. This can be easily achieved though by modifying the model training code to run the last few layers under a different `autocast` (or nesting 2 autocasts in order to override the recipe for a part of the network).\n",
"\n",
"The full linear layer utilizing NVFP4 is presented in Figure 9.\n",
"\n",
......@@ -193,7 +193,7 @@
"source": [
"### FP8 autocasting\n",
"\n",
"Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager."
"Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager."
]
},
{
......@@ -212,7 +212,7 @@
"\n",
"inp = torch.rand((1024, 768)).cuda()\n",
"\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" out_fp8 = my_linear(inp)"
]
},
......@@ -221,7 +221,7 @@
"id": "e41161f1",
"metadata": {},
"source": [
"The `fp8_autocast` context manager hides the complexity of handling FP8:\n",
"The `autocast` context manager hides the complexity of handling FP8:\n",
"\n",
"- All FP8-safe operations have their inputs cast to FP8\n",
"- Amax history is updated\n",
......@@ -243,9 +243,9 @@
"source": [
"### Handling backward pass\n",
"\n",
"When a model is run inside the `fp8_autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, `fp8_autocast` context manager aggregates the tensors before performing the communication.\n",
"When a model is run inside the `autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, `autocast` context manager aggregates the tensors before performing the communication.\n",
"\n",
"Due to this aggregation the backward call needs to happen outside of the `fp8_autocast` context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass."
"Due to this aggregation the backward call needs to happen outside of the `autocast` context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass."
]
},
{
......@@ -257,11 +257,11 @@
"source": [
"loss_fp8 = out_fp8.mean()\n",
"\n",
"loss_fp8.backward() # This backward pass uses FP8, since out_fp8 was calculated inside fp8_autocast\n",
"loss_fp8.backward() # This backward pass uses FP8, since out_fp8 was calculated inside autocast\n",
"\n",
"out_fp32 = my_linear(inp)\n",
"loss_fp32 = out_fp32.mean()\n",
"loss_fp32.backward() # This backward pass does not use FP8, since out_fp32 was calculated outside fp8_autocast"
"loss_fp32.backward() # This backward pass does not use FP8, since out_fp32 was calculated outside autocast"
]
},
{
......@@ -451,9 +451,9 @@
"\n",
"inp = inp.bfloat16()\n",
"\n",
"with te.fp8_autocast(fp8_recipe=nvfp4_recipe):\n",
"with te.autocast(recipe=nvfp4_recipe):\n",
" y = my_linear1(inp)\n",
" with te.fp8_autocast(fp8_recipe=mxfp8_recipe):\n",
" with te.autocast(recipe=mxfp8_recipe):\n",
" out = my_linear2(y)\n",
"\n",
"print(out)\n",
......
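The nested-autocast hunk above is the mechanism behind the "last few layers in higher precision" recommendation: the inner autocast overrides the recipe for the modules it wraps. A self-contained sketch, here built with DelayedScaling and MXFP8BlockScaling because those constructors appear elsewhere in this diff (the notebook itself nests an NVFP4 recipe with an MXFP8 one); hardware support for the chosen recipes is assumed:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

    linear1 = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
    linear2 = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
    x = torch.randn(1024, 1024, device="cuda").bfloat16()

    outer_recipe = DelayedScaling()
    inner_recipe = MXFP8BlockScaling()

    # linear1 runs under the outer recipe; the inner autocast overrides the
    # recipe for linear2 only.
    with te.autocast(recipe=outer_recipe):
        y = linear1(x)
        with te.autocast(recipe=inner_recipe):
            out = linear2(y)
    print(out)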
......@@ -80,7 +80,7 @@
"model = Model().eval().cuda()\n",
"inps = (torch.randn([S, B, H], device=\"cuda\"),)\n",
"def _inference(fp8_enabled):\n",
" with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8_enabled):\n",
" with torch.no_grad(), te.pytorch.autocast(enabled=fp8_enabled):\n",
" model(*inps)\n",
"\n",
"te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))\n",
......@@ -138,7 +138,7 @@
"from transformer_engine.pytorch.export import te_translation_table\n",
"\n",
"def export(model, fname, inputs, fp8=True):\n",
" with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8):\n",
" with torch.no_grad(), te.pytorch.autocast(enabled=fp8):\n",
" # ! IMPORTANT !\n",
" # Transformer Engine models must have warm-up run\n",
" # before export. FP8 recipe during warm-up should \n",
......
......@@ -548,7 +548,7 @@
"\n",
"</div>\n",
"\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options."
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager. Note that autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options."
]
},
{
......@@ -567,7 +567,7 @@
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
"torch.manual_seed(1234)\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"with te.autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = te_transformer(x, attention_mask=None)"
]
},
......@@ -591,7 +591,7 @@
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
}
......
......@@ -13,7 +13,7 @@ def speedometer(
input: torch.Tensor,
output_grad: torch.Tensor,
forward_kwargs: dict = {},
fp8_autocast_kwargs: Optional[dict] = None,
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
) -> None:
......@@ -23,20 +23,20 @@ def speedometer(
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
if fp8_autocast_kwargs is None:
fp8_autocast_kwargs = {"enabled": False}
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
# Warmup runs
torch.cuda.synchronize()
for _ in range(warmup_iters):
with te.fp8_autocast(**fp8_autocast_kwargs):
with te.autocast(**autocast_kwargs):
output = module(input, **forward_kwargs)
output.backward(output_grad)
# Timing runs
start.record()
for _ in range(timing_iters):
with te.fp8_autocast(**fp8_autocast_kwargs):
with te.autocast(**autocast_kwargs):
output = module(input, **forward_kwargs)
output.backward(output_grad)
end.record()
......
......@@ -14,7 +14,7 @@ from torch.amp import autocast
import transformer_engine as te
from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.fp8 import get_default_fp8_recipe
from transformer_engine.pytorch.quantization import get_default_fp8_recipe
import transformers
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel
......@@ -461,8 +461,8 @@ class TEGemmaForCausalLM(GemmaForCausalLM):
# Both autocasts are needed: FP8 for operations that can run in lower
# precision and BF16 for those that cannot.
with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast(
enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None
with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.autocast(
enabled=self.config.fp8, recipe=self.fp8_recipe if self.config.fp8 else None
):
lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze()
# If padding is at the beginning, then shift it to the end
......@@ -694,8 +694,8 @@ class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM):
graphed_function = te.pytorch.make_graphed_callables(
function,
(input_tensor,),
fp8_enabled=self.config.fp8,
fp8_recipe=fp8_recipe,
enabled=self.config.fp8,
recipe=fp8_recipe,
allow_unused_input=True,
num_warmup_iters=5,
sample_kwargs=sample_kwargs,
......
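The te_gemma.py hunk above shows the matching rename for CUDA-graph capture: make_graphed_callables now takes enabled= and recipe= instead of fp8_enabled= and fp8_recipe=. A reduced sketch of graphing a single TE module under the new argument names (module and shapes are illustrative):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    layer = te.Linear(1024, 1024).cuda()
    sample_input = torch.randn(128, 1024, device="cuda", requires_grad=True)
    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)

    graphed_layer = te.make_graphed_callables(
        layer,
        (sample_input,),
        num_warmup_iters=5,
        enabled=True,       # previously fp8_enabled=True
        recipe=fp8_recipe,  # previously fp8_recipe=fp8_recipe
    )

    y = graphed_layer(sample_input)
    y.sum().backward()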
......@@ -9,7 +9,7 @@ import torch
from typing import List
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.quantization import quantized_model_init
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import get_checkpoint_shard_files
......@@ -88,10 +88,10 @@ def load_te_model(cls, config):
config.use_cache = False # To make TransformerLayer compatible with GemmaModel
# Loading model with FP8 only weights needs both the following context managers.
# 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights.
# 1. quantized_model_init(config.quantized_model_init) to tell TE to use FP8 only weights.
# 2. torch.no_grad() during TE modules' initialization so that they respect
# the `fp8_model_init` context manager.
with torch.no_grad(), fp8_model_init(config.fp8_model_init):
# the `quantized_model_init` context manager.
with torch.no_grad(), quantized_model_init(config.quantized_model_init):
# Just create a model with random weights.
vanilla_model = cls(config).cuda()
......
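fp8_model_init becomes quantized_model_init. A minimal sketch of creating a module whose weights are stored directly in the quantized format, mirroring the loading code above (the layer size is illustrative, and enabled=True is passed unconditionally here):

    import torch
    import transformer_engine.pytorch as te

    # Weights are created already quantized, so no high-precision master copy is
    # kept; torch.no_grad() mirrors the loading code in the hunk above.
    with torch.no_grad(), te.quantized_model_init(enabled=True):
        linear_fp8 = te.Linear(4096, 4096)

    print(f"GPU memory after FP8-only init: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")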
......@@ -77,7 +77,7 @@
"\n",
"This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n",
"\n",
"If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n",
"If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n",
"\n",
"It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n",
"\n",
......@@ -94,12 +94,12 @@
"\n",
"The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n",
"\n",
"The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n",
"The Transformer Engine includes a wrapper `quantized_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\">\n",
"<figcaption>\n",
"Figure 3: Model under <b>fp8_autocast()</b> stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using <b>fp8_model_init()</b> results in storing model weights in FP8 by default, which can help with these potential issues.\n",
"Figure 3: Model under <b>autocast()</b> stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using <b>quantized_model_init()</b> results in storing model weights in FP8 by default, which can help with these potential issues.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
......@@ -405,8 +405,8 @@
" graphed_function = te.pytorch.make_graphed_callables(\n",
" function,\n",
" (input_tensor,),\n",
" fp8_enabled=self.config.fp8,\n",
" fp8_recipe=fp8_recipe,\n",
" enabled=self.config.fp8,\n",
" recipe=fp8_recipe,\n",
" allow_unused_input=True,\n",
" num_warmup_iters=5,\n",
" sample_kwargs=sample_kwargs,\n",
......@@ -540,14 +540,14 @@
"source": [
"### Calibrating FP8 scaling factors for correctness\n",
"\n",
"Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n",
"Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n",
"\n",
"1. Model weight tensors\n",
"2. Input tensors\n",
"\n",
"If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n",
"\n",
"To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n",
"To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n",
"\n",
"*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n",
" \n",
......@@ -590,14 +590,14 @@
"model = init_te_gemma_model(run_config)\n",
"\n",
"# Calibration\n",
"with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n",
"with te.autocast(enabled=False, calibrating=True), torch.autocast(\n",
" device_type=\"cuda\", dtype=torch.bfloat16\n",
"):\n",
" model.train()\n",
" run_forward_pass(model, run_config, num_iters=64)\n",
"\n",
"# Compute scale_fwd with enabled fp8 autocast\n",
"with te.fp8_autocast(enabled=True), torch.autocast(\n",
"with te.autocast(enabled=True), torch.autocast(\n",
" device_type=\"cuda\", dtype=torch.bfloat16\n",
"):\n",
" run_forward_pass(model, run_config, 1)\n",
......@@ -734,7 +734,7 @@
"2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n",
"\n",
"\n",
"Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. Let's see a small example:"
"Transformer Engine supports maintaining FP8-only weights with the `quantized_model_init` context manager. Let's see a small example:"
]
},
{
......@@ -778,7 +778,7 @@
"del linear_bf16\n",
"\n",
"# Initialize model weights in FP8 precision\n",
"with torch.no_grad(), te.fp8_model_init(enabled=True):\n",
"with torch.no_grad(), te.quantized_model_init(enabled=True):\n",
" linear_fp8 = te.Linear(H, D)\n",
"print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
"del linear_fp8"
......@@ -793,11 +793,11 @@
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init_2_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n",
" Figure 8: Using quantized_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"Let's run the code with `fp8_model_init`:"
"Let's run the code with `quantized_model_init`:"
]
},
{
......@@ -862,7 +862,7 @@
"\n",
"# Enable FP8 math and FP8 model weights\n",
"run_config.fp8 = True\n",
"run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n",
"run_config.quantized_model_init = True # This will result in storing only fp8 weights.\n",
"run_config.fp8_model_weights_filename = (\n",
" \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n",
")\n",
......@@ -885,7 +885,7 @@
"| HF (baseline) | 46.6 s | - | - | - |\n",
"| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |"
"| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `quantized_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |"
]
},
{
......@@ -911,7 +911,7 @@
"It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n",
"\n",
"1. Longer context lengths (with paged KV cache) \n",
"2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n",
"2. Using less memory during generation (by storing weights in FP8 precision using `quantized_model_init`)\n",
"\n",
"Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models."
]
......
......@@ -34,7 +34,7 @@ class RunConfiguration:
# FP8 precision settings
self.fp8 = False
self.fp8_model_weights_filename = None
self.fp8_model_init = False
self.quantized_model_init = False
# Cuda graphs
self.generation_cuda_graphs = False
......
......@@ -15,8 +15,8 @@ Here, we take the `MultiheadAttention` module as an example. Its FP8 attention m
.. code-block:: python
>>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
>>> with fp8_model_init(enabled=True):
>>> from transformer_engine.pytorch import MultiheadAttention, quantized_model_init
>>> with quantized_model_init(enabled=True):
... mha = MultiheadAttention(
... hidden_size=1024,
... num_attention_heads=16,
......
......@@ -24,7 +24,7 @@ from common import (
from transformer_engine.jax.dense import dense
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOp,
CollectiveOpSet,
......@@ -98,12 +98,12 @@ def run_dense_grad_tests(args, mesh=None):
)
collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op)
with mesh, fp8_autocast(
with mesh, autocast(
enabled=False,
fp8_recipe=None,
recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
# Get the base axis rules and extend them with TE's rules. This must be done inside autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
......
......@@ -32,7 +32,7 @@ from common import (
)
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp
from transformer_engine.jax.sharding import MeshResource
......@@ -109,9 +109,9 @@ def run_gemm_tests(args, mesh=None):
else CollectiveOp.REDUCE_SCATTER
)
with mesh, fp8_autocast(
with mesh, autocast(
enabled=False,
fp8_recipe=None,
recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
print(f"Device mesh: {mesh}")
......
......@@ -24,7 +24,7 @@ from common import (
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOpSet,
CollectiveOp,
......@@ -151,12 +151,12 @@ def run_layernorm_mlp_grad_tests(args, mesh=None):
collective_op_sets = (collective_op_set_1, collective_op_set_2)
noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set)
with mesh, fp8_autocast(
with mesh, autocast(
enabled=False,
fp8_recipe=None,
recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
# Get the base axis rules and extend them with TE's rules. This must be done inside autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
......
......@@ -8,7 +8,7 @@ This example uses Transformer Encoder to demonstrate the Transformer Engine usag
2. Define model: The `Net` class is a small Transformer Encoder model for sentence classification. The Transformer Engine provides `te.TransformerLayer` as encoder block and `te.DenseGeneral`. The structure of encoder block can be referred to [Scaling Up Models and Data with t5x and seqio](https://arxiv.org/abs/2203.17189)
3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use `fp8_autocast` context manager to enable FP8 training and check `var_collect` if the variable collection contains `Float8`.
3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use `autocast` context manager to enable FP8 training and check `var_collect` if the variable collection contains `Float8`.
4. Training process: In `train_step`, combine the FP8 metadata and latest model parameters into var_collect as a frozen dictionary and fill it to the gradient function.
......@@ -29,7 +29,7 @@ python test_single_gpu_encoder.py --use-fp8
3. On the model side, the logical axis of each weight tensor of the model can be named. The `te.TransformerLayer` has the default names, which are stored in `abs_var_collect`, a collection of variables returned by `jax.eval_shape(encoder.init, ...)`. The key index is `params_axes`. The `te.DenseGeneral` doesn't have the default named axis because it is generic. Also, data-parallel sharding doesn't need to divide weight tensor, so named axis is not required for this case.
4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis.
4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis.
5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for parallel jit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example.
......
......@@ -269,9 +269,9 @@ def train_and_evaluate(args):
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, te.fp8_autocast(
) as mesh, te.autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
recipe=fp8_recipe,
mesh_resource=te.MeshResource(
dp_resource=DEVICE_DP_AXIS,
tpsp_resource=DEVICE_TP_AXIS,
......@@ -287,7 +287,7 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
# Get the base axis rules and extend them with TE's rules. This must be done inside autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
......
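On the JAX side the rename is the same: transformer_engine.jax.autocast (also exposed from transformer_engine.jax.quantize) takes recipe= and mesh_resource=. A sketch mirroring the sharded examples above, assuming at least one visible JAX device; the axis names are illustrative:

    import jax
    import flax
    from jax.experimental import mesh_utils
    import transformer_engine.jax as te
    import transformer_engine.jax.flax as te_flax
    from transformer_engine.common.recipe import DelayedScaling

    device_mesh = mesh_utils.create_device_mesh((jax.device_count(), 1))
    with jax.sharding.Mesh(devices=device_mesh, axis_names=("dp", "tpsp")) as mesh, te.autocast(
        enabled=True,
        recipe=DelayedScaling(),
        mesh_resource=te.MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
    ):
        # Axis rules must be extended inside the autocast context, as noted in
        # the examples above.
        axis_rules = flax.linen.get_logical_axis_rules()
        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)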