Unverified Commit 8e7f4c8c authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Documentation for advanced performance optimizations (#20)



* Documentation for advanced perf optimizations

Fix bug where we were doing backward passes inside fp8_autocast in example notebooks.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Minor tweaks to advanced perf optimization docs

Review suggestions from @ptrendx
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Rewording sequence parallelism in advanced perf optimization docs

Review suggestion from @ksivaman
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 681bf9ad
{
"cells": [
{
"cell_type": "markdown",
"id": "24184f3f",
"metadata": {},
"source": [
"# Advanced Performance Optimizations"
]
},
{
"cell_type": "markdown",
"id": "6dcbf25a",
"metadata": {},
"source": [
"This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2b53dfa7",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import transformer_engine.pytorch as te\n",
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"import quickstart_utils as utils\n",
"\n",
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = torch.float16\n",
"\n",
"# Synthetic data\n",
"x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)\n",
"dy = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b96a9ef6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 27.82952880859375 ms\n"
]
}
],
"source": [
"# Construct layer\n",
"basic_transformer = te.TransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
")\n",
"basic_transformer.to(dtype=dtype).cuda()\n",
"\n",
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(\n",
" fp8_format=fp8_format,\n",
" amax_history_len=16,\n",
" amax_compute_algo=\"max\",\n",
")\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = basic_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"\n",
"# Measure step time\n",
"utils.speedometer(\n",
" basic_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "11367f5b",
"metadata": {},
"source": [
"## Multi-GPU training\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We parallelize a Transformer layer with data, tensor, and sequence parallelism.\n",
"\n",
"</div>\n",
"\n",
"A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\times \\text{batch_size} \\times \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n",
"\n",
"To show this in action, let's first initialize NCCL with a trivial process group:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fca06ec3",
"metadata": {},
"outputs": [],
"source": [
"# Configure parallel groups\n",
"import os\n",
"import torch\n",
"world_group = torch.distributed.init_process_group(\n",
" \"nccl\",\n",
" init_method=\"file:///tmp/rdzv\",\n",
" world_size=1,\n",
" rank=0,\n",
")\n",
"data_parallel_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")\n",
"tensor_parallel_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")"
]
},
{
"cell_type": "markdown",
"id": "1f2b80d0",
"metadata": {},
"source": [
"We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
"\n",
"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). FP8 training requires extra synchronization for the scaling factors, so the data-parallel process group must also be passed to the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1892cc9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 29.09606689453125 ms\n"
]
}
],
"source": [
"# Construct layer\n",
"parallel_transformer = te.TransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
" set_parallel_mode=True,\n",
" tp_group=tensor_parallel_group,\n",
" sequence_parallel=True,\n",
")\n",
"parallel_transformer.to(dtype=dtype).cuda()\n",
"parallel_transformer = torch.nn.parallel.DistributedDataParallel(\n",
" parallel_transformer,\n",
" process_group=data_parallel_group,\n",
")\n",
"\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group):\n",
" y = parallel_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"\n",
"# Measure step time\n",
"utils.speedometer(\n",
" parallel_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = {\n",
" \"enabled\": True,\n",
" \"fp8_recipe\": fp8_recipe,\n",
" \"fp8_group\": data_parallel_group,\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5f03f6d8",
"metadata": {},
"source": [
"## Gradient accumulation fusion\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We take advantage of the ability of Tensor Cores to accumulate outputs directly into FP32.\n",
"\n",
"</div>\n",
"\n",
"PyTorch's autograd functionality assumes that a model parameter and its corresponding gradient have the same data type. However, while low-precision data types like FP8 are sufficient for evaluating a neural network's forward and backward passes, the optimization step typically requires full FP32 precision to avoid signficant learning degradation. In addition, Tensor Cores on Hopper GPUs have the option to accumulate matrix products directly into FP32, resulting in better numerical accuracy and avoiding the need for a separate casting kernel. Thus, Transformer Engine provides an option to directly generate FP32 gradients for weight tensors. The FP32 gradients are not output to the parameter's `grad` tensor, but rather to a `main_grad` tensor that must be initialized before the backward pass."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a7f612ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 27.510029296875 ms\n"
]
}
],
"source": [
"# Construct layer\n",
"wgrad_transformer = te.TransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
" fuse_wgrad_accumulation=True,\n",
" fuse_qkv_params=True, # Required for fuse_wgrad_accumulation\n",
")\n",
"wgrad_transformer.to(dtype=dtype).cuda()\n",
"for param in wgrad_transformer.parameters():\n",
" param.grad = None\n",
" param.main_grad = torch.zeros_like(param, dtype=torch.float32)\n",
"\n",
"# Training step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = wgrad_transformer(x, attention_mask=None)\n",
"y.backward(dy)\n",
"for param in wgrad_transformer.parameters():\n",
" if param.grad is not None:\n",
" param.main_grad.copy_(param.grad)\n",
" param.grad = None\n",
"\n",
"# Measure step time\n",
"utils.speedometer(\n",
" wgrad_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "add64bd5",
"metadata": {},
"source": [
"## FP8 weight caching\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We avoid redundant FP8 casting when training with multiple gradient accumulation steps.\n",
"\n",
"</div>\n",
"\n",
"Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "abbc218e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 27.262666015625 ms\n"
]
}
],
"source": [
"# Construct layer\n",
"weight_caching_transformer = te.TransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
")\n",
"weight_caching_transformer.to(dtype=dtype).cuda()\n",
"\n",
"# Cast weights in first gradient accumulation step\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=True)\n",
"y.backward(dy)\n",
"\n",
"# Reuse FP8 weights in subsequent gradient accumulation steps\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=False)\n",
"y.backward(dy)\n",
"\n",
"# Measure step time\n",
"utils.speedometer(\n",
" weight_caching_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None, \"is_first_microbatch\": False },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
...@@ -162,9 +162,9 @@ ...@@ -162,9 +162,9 @@
], ],
"source": [ "source": [
"basic_transformer = BasicTransformerLayer(\n", "basic_transformer = BasicTransformerLayer(\n",
" hidden_size, \n", " hidden_size,\n",
" ffn_hidden_size, \n", " ffn_hidden_size,\n",
" num_attention_heads\n", " num_attention_heads,\n",
")\n", ")\n",
"basic_transformer.to(dtype=dtype).cuda()" "basic_transformer.to(dtype=dtype).cuda()"
] ]
...@@ -190,12 +190,17 @@ ...@@ -190,12 +190,17 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 41.4469287109375 ms\n" "Mean time: 43.0663916015625 ms\n"
] ]
} }
], ],
"source": [ "source": [
"utils.speedometer(basic_transformer, x, dy)" "utils.speedometer(\n",
" basic_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
] ]
}, },
{ {
...@@ -340,12 +345,17 @@ ...@@ -340,12 +345,17 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 41.3155712890625 ms\n" "Mean time: 43.1413232421875 ms\n"
] ]
} }
], ],
"source": [ "source": [
"utils.speedometer(basic_te_transformer, x, dy)" "utils.speedometer(\n",
" basic_te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
] ]
}, },
{ {
...@@ -456,12 +466,17 @@ ...@@ -456,12 +466,17 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 41.5097509765625 ms\n" "Mean time: 43.1981201171875 ms\n"
] ]
} }
], ],
"source": [ "source": [
"utils.speedometer(fused_te_transformer, x, dy)" "utils.speedometer(\n",
" fused_te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
] ]
}, },
{ {
...@@ -505,12 +520,17 @@ ...@@ -505,12 +520,17 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 38.391796875 ms\n" "Mean time: 39.99169921875 ms\n"
] ]
} }
], ],
"source": [ "source": [
"utils.speedometer(te_transformer, x, dy)" "utils.speedometer(\n",
" te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
] ]
}, },
{ {
...@@ -528,7 +548,7 @@ ...@@ -528,7 +548,7 @@
"\n", "\n",
"</div>\n", "</div>\n",
"\n", "\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options." "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options."
] ]
}, },
{ {
...@@ -548,7 +568,7 @@ ...@@ -548,7 +568,7 @@
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
"torch.manual_seed(1234)\n", "torch.manual_seed(1234)\n",
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = te_transformer(x, attention_mask=None)\n" " y = te_transformer(x, attention_mask=None)"
] ]
}, },
{ {
...@@ -561,13 +581,18 @@ ...@@ -561,13 +581,18 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 27.991220703125 ms\n" "Mean time: 28.61394775390625 ms\n"
] ]
} }
], ],
"source": [ "source": [
"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", "utils.speedometer(\n",
" utils.speedometer(te_transformer, x, dy)" " te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
")"
] ]
} }
], ],
......
...@@ -5,11 +5,15 @@ ...@@ -5,11 +5,15 @@
import math import math
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import DelayedScaling, dist_group_type
def speedometer( def speedometer(
module: torch.nn.Module, module: torch.nn.Module,
input: torch.Tensor, input: torch.Tensor,
output_grad: torch.Tensor, output_grad: torch.Tensor,
forward_kwargs: dict = {},
fp8_autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50, timing_iters: int = 50,
warmup_iters: int = 50, warmup_iters: int = 50,
) -> None: ) -> None:
...@@ -19,17 +23,21 @@ def speedometer( ...@@ -19,17 +23,21 @@ def speedometer(
""" """
start = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
if fp8_autocast_kwargs is None:
fp8_autocast_kwargs = { "enabled": False }
# Warmup runs # Warmup runs
torch.cuda.synchronize() torch.cuda.synchronize()
for _ in range(warmup_iters): for _ in range(warmup_iters):
output = module(input, attention_mask=None) with te.fp8_autocast(**fp8_autocast_kwargs):
output = module(input, **forward_kwargs)
output.backward(output_grad) output.backward(output_grad)
# Timing runs # Timing runs
start.record() start.record()
for _ in range(timing_iters): for _ in range(timing_iters):
output = module(input, attention_mask=None) with te.fp8_autocast(**fp8_autocast_kwargs):
output = module(input, **forward_kwargs)
output.backward(output_grad) output.backward(output_grad)
end.record() end.record()
torch.cuda.synchronize() torch.cuda.synchronize()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment