"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "5fdd7bb94c702a2e48ccd060b84362b9e63ac684"
Unverified commit 9ff2c076, authored by Tim Moon, committed by GitHub

Use correct FP8 group in multi-GPU docs (#852)

* Use correct FP8 group in multi-GPU docs

FP8 process group should be tensor-parallel group
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Synchronize FP8 scales over world group in multi-GPU docs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 9bd938bc
@@ -115,12 +115,13 @@
 "# Configure parallel groups\n",
 "import os\n",
 "import torch\n",
-"world_group = torch.distributed.init_process_group(\n",
+"torch.distributed.init_process_group(\n",
 "    \"nccl\",\n",
 "    init_method=\"file:///tmp/rdzv\",\n",
 "    world_size=1,\n",
 "    rank=0,\n",
 ")\n",
+"world_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")\n",
 "data_parallel_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")\n",
 "tensor_parallel_group = torch.distributed.new_group(ranks=[0], backend=\"nccl\")"
@@ -132,7 +133,9 @@
 "source": [
 "We only initialize with one GPU to keep this example simple. Please consult the [torch.distributed](https://pytorch.org/docs/stable/distributed.html) documentation for guidance on running with multiple GPUs. Note that we require each distributed process to correspond to exactly one GPU, so we treat the two interchangeably. In practice, multiple factors can affect the optimal parallel layout: the system hardware, the network topology, and the use of other parallelism schemes like pipeline parallelism. A rough rule of thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n",
 "\n",
-"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). FP8 training requires extra synchronization for the scaling factors, so the data-parallel process group must also be passed to the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group. In this case, the tensor parallel group must also be passed to the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager, either directly or as a subset of a larger distributed group."
+"Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it is applied to operations that are not amenable to tensor parallelism, using the tensor-parallel process group.\n",
+"\n",
+"One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager."
 ]
 },
 {
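The notebook deliberately initializes with a single GPU, so all three groups above contain only rank 0. As an illustration of the 2D-grid rule of thumb in the paragraph above, here is a minimal sketch of how the groups might be built in a real multi-process launch. It assumes one process per GPU started with `torchrun` and homogeneous nodes; the loop structure is illustrative only and is not part of this commit:

```python
import torch

# Assumes torchrun provides the env:// rendezvous variables (one process per GPU).
torch.distributed.init_process_group("nccl")
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
gpus_per_node = torch.cuda.device_count()  # assumes homogeneous nodes
num_nodes = world_size // gpus_per_node

# Group spanning every rank, used to synchronize FP8 scaling factors.
world_group = torch.distributed.new_group(ranks=list(range(world_size)), backend="nccl")

# Rows of the (num_nodes x gpus_per_node) grid are tensor-parallel groups.
# Every rank must create every group, in the same order, and keep only its own.
tensor_parallel_group = None
for node in range(num_nodes):
    ranks = list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
    group = torch.distributed.new_group(ranks=ranks, backend="nccl")
    if rank in ranks:
        tensor_parallel_group = group

# Columns of the grid are data-parallel groups.
data_parallel_group = None
for col in range(gpus_per_node):
    ranks = list(range(col, world_size, gpus_per_node))
    group = torch.distributed.new_group(ranks=ranks, backend="nccl")
    if rank in ranks:
        data_parallel_group = group
```

With more than one GPU, each rank then holds a tensor-parallel row and a data-parallel column, while `world_group` still spans every rank for the FP8 scale reductions that this commit routes through `fp8_group`.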
@@ -166,7 +169,7 @@
 ")\n",
 "\n",
 "# Training step\n",
-"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group):\n",
+"with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):\n",
 "    y = parallel_transformer(x, attention_mask=None)\n",
 "y.backward(dy)\n",
 "\n",
@@ -179,7 +182,7 @@
 "    fp8_autocast_kwargs = {\n",
 "        \"enabled\": True,\n",
 "        \"fp8_recipe\": fp8_recipe,\n",
-"        \"fp8_group\": data_parallel_group,\n",
+"        \"fp8_group\": world_group,\n",
 "    },\n",
 ")"
 ]
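The hunks above change only the `fp8_group` argument. For context, the surrounding pattern the updated paragraph describes, DDP wrapping plus FP8 scale synchronization over the world group, might look like the following sketch. It reuses the groups from the sketch above; the layer sizes and recipe settings are placeholders, not taken from this diff:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Placeholder model and data dimensions (not from this commit).
hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
sequence_length, batch_size = 128, 2
torch.cuda.set_device(rank % gpus_per_node)  # rank, gpus_per_node from the sketch above

# Transformer Engine handles tensor and sequence parallelism natively.
parallel_transformer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_heads,
    set_parallel_mode=True,
    tp_group=tensor_parallel_group,
    sequence_parallel=True,
).cuda()

# Data parallelism uses standard PyTorch DDP over the data-parallel group.
parallel_transformer = torch.nn.parallel.DistributedDataParallel(
    parallel_transformer,
    process_group=data_parallel_group,
)

# Synchronize FP8 scaling factors over the world group, which covers
# both the data-parallel and tensor-parallel groups.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
x = torch.randn(sequence_length, batch_size, hidden_size, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):
    y = parallel_transformer(x, attention_mask=None)
y.backward(torch.randn_like(y))
```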