Unverified commit 666539f3 authored by Tian Zheng, committed by GitHub

[Paddle] Add TP overlap (#443)



* Add async TP comm
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add warning to set CUDA_DEVICE_MAX_CONNECTIONS for TP overlap
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 88649838
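The diff below adds an asynchronous code path to `allreduce` and warns when `CUDA_DEVICE_MAX_CONNECTIONS` is not 1. A minimal sketch of the launcher-side setup this assumes (the setup itself is not part of this diff):

```python
# Hedged sketch: restrict the device to a single hardware work queue so that
# NCCL kernels enqueued first by the host cannot be reordered behind the
# GEMM kernels -- this is what lets communication overlap compute.
import os

# Must happen before CUDA is initialized (i.e. before importing paddle).
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

import paddle  # noqa: E402  -- imported after the env var on purpose
```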
@@ -3,8 +3,10 @@
 # See LICENSE for license information.
 """Methods needed for distributed training."""
+import os
+import warnings
 from contextlib import contextmanager
-from typing import Optional, Union, Tuple
+from typing import Any, Optional, Union, Tuple
 import paddle
@@ -35,6 +37,19 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
                             if tp_group is None else tp_group)
     world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
                   if tp_group is None else tp_group.nranks)
+
+    # When using TP, the NCCL communication needs to be scheduled
+    # before the GEMM for a guaranteed overlap. On the host side,
+    # TE always launches the communication calls first, but to
+    # ensure the GEMM isn't also scheduled first on the device,
+    # the environment variable `CUDA_DEVICE_MAX_CONNECTIONS` needs
+    # to be set to 1 so that all kernels are forced into a single
+    # hardware work queue.
+    num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
+    if num_cuda_work_queues != 1:
+        warnings.warn("To guarantee overlapping TP and SP collectives with the backward "
+                      "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1")
+
     return model_parallel_group, world_size
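Note that `os.getenv` falls back to `"0"` here, so the warning fires both when the variable is unset and when it is set to anything other than 1. A quick demonstration of the check in isolation:

```python
import os

# Unset: getenv falls back to "0", the check fails, the warning would fire.
os.environ.pop("CUDA_DEVICE_MAX_CONNECTIONS", None)
assert int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1

# Set to 1: the check passes and no warning is emitted.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
assert int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) == 1
```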
@@ -69,7 +84,8 @@ def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool,
 def allreduce(
     input_: paddle.Tensor,
     tp_group: Optional[dist_group_type] = None,
-) -> paddle.Tensor:
+    sync_op: bool = True,
+) -> Tuple[paddle.Tensor, Any]:
     """All-reduce the input tensor across model parallel group."""
 
     # Bypass the function if we are using only 1 GPU.
@@ -77,14 +93,25 @@ def allreduce(
         return input_
 
     # All-reduce.
-    output = mp_ops._mp_allreduce(
-        input_,
-        group=tp_group,
-        use_calc_stream=True,
-        use_model_parallel=True,
-    )
-
-    return output
+    if sync_op:
+        output = mp_ops._mp_allreduce(
+            input_,
+            group=tp_group,
+            use_calc_stream=True,
+            use_model_parallel=True,
+        )
+        return output, None
+
+    wait_handle = paddle.distributed.all_reduce(
+        input_,
+        op=paddle.distributed.ReduceOp.SUM,
+        group=tp_group,
+        sync_op=False,
+    )
+    output = input_
+    return output, wait_handle
 
 
 def identity(
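Callers now unpack an `(output, wait_handle)` pair. A hedged usage sketch mirroring the two code paths (tensor and group names are illustrative):

```python
# Synchronous path: the collective has completed on return; no handle.
out, _ = allreduce(out, tp_group)

# Asynchronous path: the collective is only enqueued. Overlap independent
# work, then wait on the handle before reading the reduced tensor.
dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
# ... e.g. launch the wgrad GEMM here ...
if handle is not None:
    handle.wait()
```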
@@ -532,7 +532,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
             ln_out = identity(ln_out, self.tp_group)
         out = F.linear(ln_out, self.weight, self.bias if self.gemm_bias_fused_add else None)
         if self.parallel_mode == 'row' and self.tensor_parallel:
-            out = allreduce(out, self.tp_group)
+            out, _ = allreduce(out, self.tp_group)
             out = out + self.bias if self.bias is not None else out
         if self.return_layernorm_output:
             return out, ln_out
@@ -787,7 +787,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
         out = F.linear(act_out, self.fc2_weight,
                        self.fc2_bias if self.gemm_bias_fused_add else None)
         if self.set_parallel_mode and self.tensor_parallel:
-            out = allreduce(out, self.tp_group)
+            out, _ = allreduce(out, self.tp_group)
             out = out + self.fc2_bias if self.fc2_bias is not None else out
         if self.return_layernorm_output:
             return out, ln_out
@@ -94,7 +94,7 @@ def _linear_fwd_fp8(
 
     # Row Parallel Linear
     if parallel_mode == "row" and tensor_parallel:
-        out = allreduce(out, tp_group)
+        out, _ = allreduce(out, tp_group)
 
     return out, weight_t_fp8
@@ -146,7 +146,7 @@ def _linear_fwd_non_fp8(
         out, _, _ = outputs
 
     # Row Parallel Linear
     if parallel_mode == "row" and tensor_parallel:
-        out = allreduce(out, tp_group)
+        out, _ = allreduce(out, tp_group)
 
     return out
@@ -221,7 +221,7 @@ def _linear_bwd_fp8(
     tensor_parallel: bool,
     tp_group: Union[dist_group_type, None],
 ):
-    dgrad, wgrad = None, None
+    dgrad, wgrad, handle = None, None, None
     fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
     fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
     if requires_dgrad:
@@ -239,7 +239,7 @@ def _linear_bwd_fp8(
             use_split_accumulator=_2X_ACC_DGRAD,
         )
         if parallel_mode == "column" and tensor_parallel:
-            dgrad = allreduce(dgrad, tp_group)
+            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
 
     if requires_wgrad:
         if not fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -265,6 +265,10 @@ def _linear_bwd_fp8(
                 layout="NT",
                 grad=True,
             )
+
+    if parallel_mode == "column" and tensor_parallel and handle is not None:
+        handle.wait()
+
     return dgrad, wgrad
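The backward pass now follows a launch-compute-wait pattern: enqueue the dgrad all-reduce, run the wgrad GEMM while it is in flight, and wait only before returning. A self-contained sketch of the same pattern using plain Paddle ops (shapes and tensors are illustrative, not taken from this diff):

```python
import paddle
import paddle.distributed as dist

# Assumes a multi-GPU run started via `python -m paddle.distributed.launch`.
dist.init_parallel_env()

dgrad = paddle.randn([512, 256])      # stand-in for the dgrad GEMM output
act = paddle.randn([128, 512])        # stand-in for saved activations
grad_out = paddle.randn([128, 256])   # stand-in for the incoming gradient

# 1) Enqueue the dgrad all-reduce without blocking the host.
handle = dist.all_reduce(dgrad, op=dist.ReduceOp.SUM, sync_op=False)

# 2) Launch independent work (the wgrad GEMM). With
#    CUDA_DEVICE_MAX_CONNECTIONS=1 it cannot be scheduled ahead of the
#    collective, so the two overlap on the device.
wgrad = paddle.matmul(act, grad_out, transpose_x=True)  # [512, 256]

# 3) Block only when the reduced dgrad is actually needed.
handle.wait()
```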
@@ -285,7 +289,7 @@ def _linear_bwd_non_fp8(
     """
     Performs Linear Backward. Optionally fuses GELU backward and dbias.
     """
-    dgrad, wgrad, bgrad = None, None, None
+    dgrad, wgrad, bgrad, handle = None, None, None, None
     if requires_dgrad:
         dgrad, _, _ = gemm(
             weight,
@@ -298,7 +302,7 @@ def _linear_bwd_non_fp8(
             grad=True,
         )
         if parallel_mode == "column" and tensor_parallel:
-            dgrad = allreduce(dgrad, tp_group)
+            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
 
     if requires_wgrad:
         wgrad, bgrad, _ = gemm(
@@ -313,6 +317,9 @@ def _linear_bwd_non_fp8(
     elif requires_bgrad:
         bgrad = grad_output.sum(axis=0)
 
+    if parallel_mode == "column" and tensor_parallel and handle is not None:
+        handle.wait()
+
     return dgrad, wgrad, bgrad
@@ -676,7 +683,7 @@ class Linear(TransformerEngineBaseLayer):
             inp = identity(inp, self.tp_group)
         out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
         if self.parallel_mode == 'row' and self.tensor_parallel:
-            out = allreduce(out, self.tp_group)
+            out, _ = allreduce(out, self.tp_group)
             out = out + self.bias if self.bias is not None else out
         return out