Commit 5b6ef054 authored by yuguo's avatar yuguo
Browse files
parents 76060570 a7eeb28b
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import datetime
import os
import pathlib
import subprocess
from builtins import str
# Basic project info
project = "Transformer Engine"
author = "NVIDIA CORPORATION & AFFILIATES"

# Copyright statement: a single year at release time, "start-current" afterwards.
release_year = 2022
current_year = datetime.date.today().year
if current_year == release_year:
    copyright_year = release_year
else:
    copyright_year = f"{release_year}-{current_year}"
copyright = f"{copyright_year}, NVIDIA CORPORATION & AFFILIATES. All rights reserved."
# Transformer Engine root directory
root_path = pathlib.Path(__file__).resolve().parent.parent

# Git commit hash, used to tag dev builds. Prefer the CI-provided GIT_SHA
# environment variable; otherwise fall back to querying the local checkout.
git_sha = os.getenv("GIT_SHA")
if not git_sha:
    try:
        git_sha = (
            subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"])
            .decode("ascii")
            .replace("'", "")
            .strip()
        )
    # Narrow handling (was a bare `except:`): CalledProcessError covers
    # "not a git repository", OSError covers git not being installed.
    # Anything else (e.g. KeyboardInterrupt) should propagate.
    except (subprocess.CalledProcessError, OSError):
        git_sha = "0000000"  # placeholder for release tarballs etc.
# Normalize to the short 7-character form.
git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha
# Version string, read from VERSION.txt; dev builds are tagged with the git hash.
with open(root_path / "build_tools" / "VERSION.txt", "r") as version_file:
    _raw_version = version_file.readline().strip()
version = str(_raw_version + "-" + git_sha) if "dev" in _raw_version else str(_raw_version)
release = _raw_version
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    "sphinx.ext.ifconfig",
    "nbsphinx",  # renders the .ipynb tutorials
    "breathe",  # bridges Doxygen XML into Sphinx for the C/C++ API
    "autoapi.extension",
]
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
source_suffix = ".rst"
master_doc = "index"
pygments_style = "sphinx"
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]
html_show_sphinx = False
# NVIDIA branding overrides layered on top of the RTD theme.
html_css_files = [
    "css/nvidia_font.css",
    "css/nvidia_footer.css",
]
html_theme_options = {
    "collapse_navigation": False,
    "logo_only": False,
    "version_selector": False,
    "language_selector": False,
}
# Custom docstring section titles that napoleon renders with parameter styling.
napoleon_custom_sections = [
    ("Parallelism parameters", "params_style"),
    ("Optimization parameters", "params_style"),
    ("Values", "params_style"),
    ("Graphing parameters", "params_style"),
    ("FP8-related parameters", "params_style"),
]
# Doxygen XML output consumed by breathe.
breathe_projects = {"TransformerEngine": root_path / "docs" / "doxygen" / "xml"}
breathe_default_project = "TransformerEngine"
# autoapi is used for cross-referencing only; page generation is disabled.
autoapi_generate_api_docs = False
autoapi_dirs = [root_path / "transformer_engine"]
{
"cells": [
{
"cell_type": "markdown",
"id": "24184f3f",
"metadata": {},
"source": [
"# Performance Optimizations"
]
},
{
"cell_type": "markdown",
"id": "6dcbf25a",
"metadata": {},
"source": [
"This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2b53dfa7",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import transformer_engine.pytorch as te\n",
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"import quickstart_utils as utils\n",
"\n",
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = torch.float16\n",
"\n",
"# Synthetic data\n",
"x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)\n",
"dy = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b96a9ef6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 27.82952880859375 ms\n"
]
}
],
"source": [
"# Construct layer\n",
"basic_transformer = te.TransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
")\n",
"basic_transformer.to(dtype=dtype).cuda()\n",
"\n",
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(\n",
" fp8_format=fp8_format,\n",
" amax_history_len=16,\n",
" amax_compute_algo=\"max\",\n",
")\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = basic_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"\n",
"# Measure step time\n",
"utils.speedometer(\n",
" basic_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "11367f5b",
"metadata": {},
"source": [
"## Multi-GPU training\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We parallelize a Transformer layer with data, tensor, and sequence parallelism.\n",
"\n",
"</div>\n",
"\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\times \\text{batch_size} \\times \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"\n",
"To show this in action, let's first initialize NCCL with a trivial process group:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fca06ec3",
"metadata": {},
"outputs": [],
"source": [
"# Configure parallel groups\n",
"import os\n",
"import torch\n",
"torch.distributed.init_process_group(\n",
" \"nccl\",\n",
" init_method=\"file:///tmp/rdzv\",\n",
" world_size=1,\n",
" rank=0,\n",
")\n",
"world_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")\n",
"data_parallel_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")\n",
"tensor_parallel_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")"
]
},
{
"cell_type": "markdown",
"id": "1f2b80d0",
"metadata": {},
"source": [
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n",
"\n",
"One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1892cc9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 29.09606689453125 ms\n"
]
}
],
"source": [
"# Construct layer\n",
"parallel_transformer = te.TransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
" set_parallel_mode=True,\n",
" tp_group=tensor_parallel_group,\n",
" sequence_parallel=True,\n",
")\n",
"parallel_transformer.to(dtype=dtype).cuda()\n",
"parallel_transformer = torch.nn.parallel.DistributedDataParallel(\n",
" parallel_transformer,\n",
" process_group=data_parallel_group,\n",
")\n",
"\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):\n",
" y = parallel_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"\n",
"# Measure step time\n",
"utils.speedometer(\n",
" parallel_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = {\n",
" \"enabled\": True,\n",
" \"fp8_recipe\": fp8_recipe,\n",
" \"fp8_group\": world_group,\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5f03f6d8",
"metadata": {},
"source": [
"## Gradient accumulation fusion\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We take advantage of the ability of Tensor Cores to accumulate outputs directly into FP32.\n",
"\n",
"</div>\n",
"\n",
"PyTorch's autograd functionality assumes that a model parameter and its corresponding gradient have the same data type. However, while low-precision data types like FP8 are sufficient for evaluating a neural network's forward and backward passes, the optimization step typically requires full FP32 precision to avoid significant learning degradation. In addition, Tensor Cores on Hopper GPUs have the option to accumulate matrix products directly into FP32, resulting in better numerical accuracy and avoiding the need for a separate casting kernel. Thus, Transformer Engine provides an option to directly generate FP32 gradients for weight tensors. The FP32 gradients are not output to the parameter's `grad` tensor, but rather to a `main_grad` tensor that must be initialized before the backward pass."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a7f612ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 27.510029296875 ms\n"
]
}
],
"source": [
"# Construct layer\n",
"wgrad_transformer = te.TransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
" fuse_wgrad_accumulation=True,\n",
" fuse_qkv_params=True, # Required for fuse_wgrad_accumulation\n",
")\n",
"wgrad_transformer.to(dtype=dtype).cuda()\n",
"for param in wgrad_transformer.parameters():\n",
" param.grad = None\n",
" param.main_grad = torch.zeros_like(param, dtype=torch.float32)\n",
"\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = wgrad_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"for param in wgrad_transformer.parameters():\n",
" if param.grad is not None:\n",
" param.main_grad.copy_(param.grad)\n",
" param.grad = None\n",
"\n",
"# Measure step time\n",
"utils.speedometer(\n",
" wgrad_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "add64bd5",
"metadata": {},
"source": [
"## FP8 weight caching\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We avoid redundant FP8 casting when training with multiple gradient accumulation steps.\n",
"\n",
"</div>\n",
"\n",
"Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
"<b>Warning!</b> \n",
"\n",
"The precise numerical outputs with and without the FP8 weight caching optimization may not be bitwise identical. This is because while the weights remain frozen across a gradient accumulation cycle, the scaling factors and amaxes for the FP8 weights can change as they are updated at the end of every iteration. These changes in amax tensors are incorporated into the amax history, which is not frozen.\n",
"\n",
"</div>"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "abbc218e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 27.262666015625 ms\n"
]
}
],
"source": [
"# Construct layer\n",
"weight_caching_transformer = te.TransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
")\n",
"weight_caching_transformer.to(dtype=dtype).cuda()\n",
"\n",
"# Cast weights in first gradient accumulation step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=True)\n",
"y.backward(dy)\n",
"\n",
"# Reuse FP8 weights in subsequent gradient accumulation steps\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=False)\n",
"y.backward(dy)\n",
"\n",
"# Measure step time\n",
"utils.speedometer(\n",
" weight_caching_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None, \"is_first_microbatch\": False },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state: seed both the CPU and CUDA generators, then snapshot
# their states so reset_rng_states() can restore them for reproducible runs.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
# Transformer Engine debug verbosity (0 = off), read from the environment.
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None:
    """Restore the CPU and CUDA RNGs to the states captured at import time."""
    # The two generators are independent, so restore order does not matter.
    torch.cuda.set_rng_state(_cuda_rng_state)
    torch.set_rng_state(_cpu_rng_state)
def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    qkv_layout: str,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass

    Inputs are generated from a freshly reset RNG state, so repeated calls
    produce identical tensors. Returns the forward output and the input
    gradients (dq, dk, dv).
    """
    reset_rng_states()
    # Per-sequence lengths; all sequences use the full maximum length here.
    # NOTE(review): seqlens_q/seqlens_kv and the qkv_layout argument appear
    # unused below — presumably kept for parity with the full test suite; confirm.
    seqlens_q = torch.full(
        [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
    )
    seqlens_kv = torch.full(
        [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
    )
    # Single random buffer holding interleaved Q, K, V along dim 2 ("bs3hd").
    inp = torch.randn(
        [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk],
        dtype=dtype,
        device="cuda",
    )
    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True
    # Upstream gradient fed into backward().
    out_grad = torch.randn(
        [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v],
        dtype=dtype,
        device="cuda",
    )
    # Create attention mask / bias
    attention_mask = None
    bias = None
    if config.attn_mask_type == "arbitrary":
        # Random boolean mask over [b, h, sq, skv].
        attention_mask = torch.randint(
            -10,
            10,
            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
        ).to(dtype=torch.bool, device="cuda")
    if config.attn_bias_type == "post_scale_bias":
        # convert mask to bias: masked positions get a large negative value so
        # they vanish after softmax, unmasked positions get 0.
        attention_mask = torch.randint(
            -10,
            10,
            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
        ).to(dtype=torch.bool, device="cuda")
        bias = attention_mask.clone()
        # Smaller magnitude for non-bf16 dtypes — presumably to stay within
        # the fp16 representable range; confirm.
        neginf = -(2**50) if dtype == torch.bfloat16 else -(2**15)
        bias = torch.where(bias == 0, 0, neginf).to(dtype=dtype, device="cuda")
        bias.requires_grad = False
        # The bias itself performs the masking, so no explicit mask is passed.
        attention_mask = None
    block = DotProductAttention(
        config.num_heads,
        config.head_dim_qk,
        num_gqa_groups=config.num_gqa_groups,
        qkv_format="bshd",
        attention_dropout=config.dropout_p,
        sequence_parallel=False,
        tp_size=1,
        get_rng_state_tracker=None,
        tp_group=None,
        layer_number=1,
        attn_mask_type="no_mask",
        window_size=(-1, -1),
    ).to(dtype=dtype, device="cuda")
    # Run a forward and backward pass
    out = None
    if config.attn_mask_type == "arbitrary":
        out = block(
            q,
            k,
            v,
            attention_mask=attention_mask,  # attention_mask
            qkv_format="bshd",
            attn_mask_type=config.attn_mask_type,  # 'arbitrary'
            core_attention_bias_type=config.attn_bias_type,  # 'no_bias'
            core_attention_bias=bias,  # None
            window_size=(-1, -1),
        )
        out.backward(out_grad)
    if config.attn_bias_type == "post_scale_bias":
        out = block(
            q,
            k,
            v,
            attention_mask=attention_mask,  # None
            qkv_format="bshd",
            attn_mask_type=config.attn_mask_type,  # no_mask
            core_attention_bias_type=config.attn_bias_type,  # 'post_scale_bias'
            core_attention_bias=bias,  # bias
            window_size=(-1, -1),
        )
        out.backward(out_grad)
    return out, (q.grad, k.grad, v.grad)
dtype = torch.bfloat16

# test configs: b, h, hg, d, sq, skv, p, mask, bias
model_configs = {
    "test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
    "test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
}

print("Run with post_scale_bias:")
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
print()

print("Run with arbitrary mask:")
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")

# The two runs should match: the bias encodes the same masking as the mask.
torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for unfused_grad, fused_grad in zip(unfused_attn_bwd, fused_attn_bwd):
    torch.testing.assert_close(unfused_grad, fused_grad, atol=2.5e-2, rtol=2.5e-2)
print()
print("Test passed!")
This diff is collapsed.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys, time
import subprocess
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
)
dtype = torch.bfloat16  # data type
num_iters = 3  # number of iterations after 3 warmup iterations
ckpt_attn = False  # checkpointing
workspace_opt = True  # workspace optimization path for cuDNN attention
qkv_layout = "bshd_bshd_bshd"  # QKV memory layout
swa = False  # sliding window attention
pad_between_seqs = False  # padding between sequences for qkv_format=thd
is_training = True  # training mode

# test configs: b, h, hg, d, sq, skv, p, mask, bias
model_configs = {
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}
def example_attention(model, fused_attn_supported, flash_attn_supported):
    """Run the cuDNN and flash-attention backends for one model config and
    compare their forward outputs and input gradients."""
    config = model_configs[model]
    # Looser tolerances for bf16, tighter for other dtypes.
    tols = (
        dict(atol=2.5e-2, rtol=2.5e-2) if dtype == torch.bfloat16 else dict(atol=5e-3, rtol=5e-3)
    )

    fused_attn_out = None
    flash_attn_out = None
    if fused_attn_supported:
        print()
        print("Run cuDNN attention...")
        fused_attn_out = _run_dot_product_attention(
            dtype,
            config,
            "FusedAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )
    if flash_attn_supported:
        print()
        print("Run flash-attention...")
        flash_attn_out = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
            ckpt_attn,
            qkv_layout,
            workspace_opt,
            pad_between_seqs,
            is_training,
        )

    # Cross-check only when both backends actually ran.
    if fused_attn_supported and flash_attn_supported:
        fused_attn_fwd, fused_attn_bwd = fused_attn_out
        flash_attn_fwd, flash_attn_bwd = flash_attn_out
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for grad_idx, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[grad_idx], flash_attn_bwd[grad_idx], **tols)
        print()
        print("Test passed.")
def main():
    """Query the available attention backends for each model config and run
    the backend comparison."""
    for model in ["test_0"]:
        config = model_configs[model]
        available_backends, fused_attn_backends = _get_attention_backends(
            config,
            qkv_dtype=dtype,
            qkv_layout=qkv_layout,
            window_size=config.window_size,
            pad_between_seqs=pad_between_seqs,
        )
        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
        example_attention(model, fused_attn_supported, flash_attn_supported)


if __name__ == "__main__":
    main()
This diff is collapsed.
This diff is collapsed.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import math
from typing import Optional
import torch
import transformer_engine.pytorch as te
def speedometer(
    module: torch.nn.Module,
    input: torch.Tensor,
    output_grad: torch.Tensor,
    forward_kwargs: Optional[dict] = None,
    fp8_autocast_kwargs: Optional[dict] = None,
    timing_iters: int = 50,
    warmup_iters: int = 50,
) -> None:
    """Measure average run time for a PyTorch module

    Performs forward and backward passes.

    Parameters
    ----------
    module: torch.nn.Module
        module to benchmark; called as ``module(input, **forward_kwargs)``
    input: torch.Tensor
        input tensor for the forward pass
    output_grad: torch.Tensor
        gradient tensor passed to ``backward``
    forward_kwargs: dict, optional
        extra keyword arguments for the forward pass
    fp8_autocast_kwargs: dict, optional
        keyword arguments for ``te.fp8_autocast`` (default: FP8 disabled)
    timing_iters: int
        number of timed iterations
    warmup_iters: int
        number of untimed warmup iterations
    """
    # Fix for mutable default argument: the previous `forward_kwargs: dict = {}`
    # default was a single shared dict that could leak state between calls.
    if forward_kwargs is None:
        forward_kwargs = {}
    if fp8_autocast_kwargs is None:
        fp8_autocast_kwargs = {"enabled": False}

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup runs (untimed) amortize one-time costs such as kernel compilation.
    torch.cuda.synchronize()
    for _ in range(warmup_iters):
        with te.fp8_autocast(**fp8_autocast_kwargs):
            output = module(input, **forward_kwargs)
        output.backward(output_grad)

    # Timing runs, measured with CUDA events so GPU work is fully included.
    start.record()
    for _ in range(timing_iters):
        with te.fp8_autocast(**fp8_autocast_kwargs):
            output = module(input, **forward_kwargs)
        output.backward(output_grad)
    end.record()
    torch.cuda.synchronize()
    print(f"Mean time: {start.elapsed_time(end)/timing_iters} ms")
class DotProductAttention(torch.nn.Module):
    """Scaled dot-product attention built from plain PyTorch ops.

    Expects query/key/value shaped ``[seq, batch, heads, head_dim]`` and
    returns a ``[seq, batch, heads * head_dim]`` context tensor.
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        attention_dropout: float,
    ) -> None:
        super().__init__()
        self.projection_size = kv_channels * num_attention_heads
        self.hidden_size_per_attention_head = kv_channels
        # Scores are divided by sqrt(head_dim), the usual attention scaling.
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.dropout = torch.nn.Dropout(attention_dropout)

    def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Softmax over the last dim; masked positions are filled in place first."""
        if mask is not None:
            inp.masked_fill_(mask, -10000.0)
        return torch.nn.functional.softmax(inp, dim=-1)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        q_len, batch, heads = query.size(0), query.size(1), query.size(2)
        k_len = key.size(0)
        head_dim = value.size(3)

        # Fold batch and head dims together and move them to the front:
        # [s, b, h, d] -> [b*h, s, d]
        flat_q = query.view(q_len, batch * heads, -1).transpose(0, 1)
        flat_k = key.view(k_len, batch * heads, -1).transpose(0, 1)

        # Raw attention scores: [b*h, sq, sk], scaled by 1/sqrt(head_dim).
        scores = torch.bmm(flat_q, flat_k.transpose(1, 2)) / self.norm_factor
        scores = scores.view(batch, heads, q_len, k_len)

        probs = self.dropout(self.masked_softmax(scores, attention_mask))

        # Weighted sum over values: [b*h, sq, hn]
        flat_v = value.view(k_len, batch * heads, -1).transpose(0, 1)
        context = torch.bmm(probs.view(batch * heads, q_len, -1), flat_v)

        # [b*h, sq, hn] -> [b, h, sq, hn] -> [sq, b, h, hn] -> [sq, b, h*hn]
        context = context.view(batch, heads, q_len, head_dim)
        context = context.permute(2, 0, 1, 3).contiguous()
        return context.view(q_len, batch, self.projection_size)
class BasicMLP(torch.nn.Module):
    """Two-layer feed-forward network (Linear -> tanh-GELU -> Linear)
    built with plain PyTorch modules."""

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
    ) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
        self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = torch.nn.functional.gelu(self.linear1(x), approximate="tanh")
        return self.linear2(hidden)
def share_parameters_with_basic_te_model(te_model, basic_model):
    """Initialize parameters for TE Transformer layer with basic modules

    Parameter values are copied from pure PyTorch implementation.
    """
    # (destination module, attribute name, source parameter) triples.
    bindings = (
        (te_model.ln1, "weight", basic_model.ln1.weight),
        (te_model.ln1, "bias", basic_model.ln1.bias),
        (te_model.qkv_projection, "weight", basic_model.qkv_projection.weight),
        (te_model.qkv_projection, "bias", basic_model.qkv_projection.bias),
        (te_model.projection, "weight", basic_model.projection.weight),
        (te_model.projection, "bias", basic_model.projection.bias),
        (te_model.ln2, "weight", basic_model.ln2.weight),
        (te_model.ln2, "bias", basic_model.ln2.bias),
        (te_model.mlp.linear1, "weight", basic_model.mlp.linear1.weight),
        (te_model.mlp.linear1, "bias", basic_model.mlp.linear1.bias),
        (te_model.mlp.linear2, "weight", basic_model.mlp.linear2.weight),
        (te_model.mlp.linear2, "bias", basic_model.mlp.linear2.bias),
    )
    for dst_module, attr_name, param in bindings:
        setattr(dst_module, attr_name, param)
def share_parameters_with_fused_te_model(te_model, basic_model):
    """Initialize parameters for TE Transformer layer with fused modules

    Parameter values are copied from pure PyTorch implementation.
    """
    # (destination module, attribute name, source parameter) triples.
    bindings = (
        (te_model.ln_qkv, "layer_norm_weight", basic_model.ln1.weight),
        (te_model.ln_qkv, "layer_norm_bias", basic_model.ln1.bias),
        (te_model.ln_qkv, "weight", basic_model.qkv_projection.weight),
        (te_model.ln_qkv, "bias", basic_model.qkv_projection.bias),
        (te_model.projection, "weight", basic_model.projection.weight),
        (te_model.projection, "bias", basic_model.projection.bias),
        (te_model.ln_mlp, "layer_norm_weight", basic_model.ln2.weight),
        (te_model.ln_mlp, "layer_norm_bias", basic_model.ln2.bias),
        (te_model.ln_mlp, "fc1_weight", basic_model.mlp.linear1.weight),
        (te_model.ln_mlp, "fc1_bias", basic_model.mlp.linear1.bias),
        (te_model.ln_mlp, "fc2_weight", basic_model.mlp.linear2.weight),
        (te_model.ln_mlp, "fc2_bias", basic_model.mlp.linear2.bias),
    )
    for dst_module, attr_name, param in bindings:
        setattr(dst_module, attr_name, param)
def share_parameters_with_transformerlayer_te_model(te_model, basic_model):
    """Initialize parameters for monolithic TE Transformer layer

    Parameter values are copied from pure PyTorch implementation.
    """
    attn_ln_qkv = te_model.self_attention.layernorm_qkv
    attn_proj = te_model.self_attention.proj
    ln_mlp = te_model.layernorm_mlp
    # (destination module, attribute name, source parameter) triples.
    bindings = (
        (attn_ln_qkv, "layer_norm_weight", basic_model.ln1.weight),
        (attn_ln_qkv, "layer_norm_bias", basic_model.ln1.bias),
        (attn_ln_qkv, "weight", basic_model.qkv_projection.weight),
        (attn_ln_qkv, "bias", basic_model.qkv_projection.bias),
        (attn_proj, "weight", basic_model.projection.weight),
        (attn_proj, "bias", basic_model.projection.bias),
        (ln_mlp, "layer_norm_weight", basic_model.ln2.weight),
        (ln_mlp, "layer_norm_bias", basic_model.ln2.bias),
        (ln_mlp, "fc1_weight", basic_model.mlp.linear1.weight),
        (ln_mlp, "fc1_bias", basic_model.mlp.linear1.bias),
        (ln_mlp, "fc2_weight", basic_model.mlp.linear2.weight),
        (ln_mlp, "fc2_bias", basic_model.mlp.linear2.bias),
    )
    for dst_module, attr_name, param in bindings:
        setattr(dst_module, attr_name, param)
def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
    """Round-trip a tensor through FP8 quantization/dequantization.

    Returns a tensor with the same dtype as ``inp`` whose values are exactly
    representable in the requested FP8 format at the given scale.
    """
    # Imported lazily so importing this module does not require the TE extensions.
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
    import transformer_engine_torch as tex

    if fp8_format == "e4m3":
        fp8_type = tex.DType.kFloat8E4M3
    else:
        fp8_type = tex.DType.kFloat8E5M2
    scale_tensor = scale * torch.ones(1, dtype=torch.float32, device="cuda")
    amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
    quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_history, fp8_dtype=fp8_type)
    return quantizer(inp).dequantize()
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment