"vscode:/vscode.git/clone" did not exist on "9e5e8bc91e30af5cdc321362b553f6c0da332e30"
Unverified Commit 666539f3 authored by Tian Zheng's avatar Tian Zheng Committed by GitHub
Browse files

[Paddle] Add TP overlap (#443)



* Add async TP comm
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add set CUDA_DEVICE_MAX_CONNECTIONS warning for tp overlap
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 88649838
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Methods needed for distributed training.""" """Methods needed for distributed training."""
import os
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Union, Tuple from typing import Any, Optional, Union, Tuple
import paddle import paddle
...@@ -35,6 +37,19 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None], ...@@ -35,6 +37,19 @@ def get_tp_group_and_world_size(tp_group: Union[dist_group_type, None],
if tp_group is None else tp_group) if tp_group is None else tp_group)
world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() world_size = (tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size()
if tp_group is None else tp_group.nranks) if tp_group is None else tp_group.nranks)
"""
When using TP, the NCCL communication needs to be scheduled
before the GEMM for a guaranteed overlap. From the host side
in TE, the comm calls are always launched first, but to ensure
that the GEMM isn't scheduled first, the environment variable
`CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to force a
single channel.
"""
num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
if num_cuda_work_queues != 1:
warnings.warn("To guarantee overlapping TP and SP collectives with the backward "
              "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1")
return model_parallel_group, world_size return model_parallel_group, world_size
...@@ -69,7 +84,8 @@ def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, ...@@ -69,7 +84,8 @@ def set_weight_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool,
def allreduce( def allreduce(
input_: paddle.Tensor, input_: paddle.Tensor,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
) -> paddle.Tensor: sync_op: bool = True,
) -> Tuple[paddle.Tensor, Any]:
"""All-reduce the input tensor across model parallel group.""" """All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
...@@ -77,14 +93,25 @@ def allreduce( ...@@ -77,14 +93,25 @@ def allreduce(
return input_ return input_
# All-reduce. # All-reduce.
output = mp_ops._mp_allreduce( if sync_op:
output = mp_ops._mp_allreduce(
input_,
group=tp_group,
use_calc_stream=True,
use_model_parallel=True,
)
return output, None
wait_handle = paddle.distributed.all_reduce(
input_, input_,
op=paddle.distributed.ReduceOp.SUM,
group=tp_group, group=tp_group,
use_calc_stream=True, sync_op=False,
use_model_parallel=True,
) )
return output output = input_
return output, wait_handle
def identity( def identity(
......
...@@ -532,7 +532,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -532,7 +532,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
ln_out = identity(ln_out, self.tp_group) ln_out = identity(ln_out, self.tp_group)
out = F.linear(ln_out, self.weight, self.bias if self.gemm_bias_fused_add else None) out = F.linear(ln_out, self.weight, self.bias if self.gemm_bias_fused_add else None)
if self.parallel_mode == 'row' and self.tensor_parallel: if self.parallel_mode == 'row' and self.tensor_parallel:
out = allreduce(out, self.tp_group) out, _ = allreduce(out, self.tp_group)
out = out + self.bias if self.bias is not None else out out = out + self.bias if self.bias is not None else out
if self.return_layernorm_output: if self.return_layernorm_output:
return out, ln_out return out, ln_out
......
...@@ -787,7 +787,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -787,7 +787,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
out = F.linear(act_out, self.fc2_weight, out = F.linear(act_out, self.fc2_weight,
self.fc2_bias if self.gemm_bias_fused_add else None) self.fc2_bias if self.gemm_bias_fused_add else None)
if self.set_parallel_mode and self.tensor_parallel: if self.set_parallel_mode and self.tensor_parallel:
out = allreduce(out, self.tp_group) out, _ = allreduce(out, self.tp_group)
out = out + self.fc2_bias if self.fc2_bias is not None else out out = out + self.fc2_bias if self.fc2_bias is not None else out
if self.return_layernorm_output: if self.return_layernorm_output:
return out, ln_out return out, ln_out
......
...@@ -94,7 +94,7 @@ def _linear_fwd_fp8( ...@@ -94,7 +94,7 @@ def _linear_fwd_fp8(
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and tensor_parallel: if parallel_mode == "row" and tensor_parallel:
out = allreduce(out, tp_group) out, _ = allreduce(out, tp_group)
return out, weight_t_fp8 return out, weight_t_fp8
...@@ -146,7 +146,7 @@ def _linear_fwd_non_fp8( ...@@ -146,7 +146,7 @@ def _linear_fwd_non_fp8(
out, _, _ = outputs out, _, _ = outputs
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and tensor_parallel: if parallel_mode == "row" and tensor_parallel:
out = allreduce(out, tp_group) out, _ = allreduce(out, tp_group)
return out return out
...@@ -221,7 +221,7 @@ def _linear_bwd_fp8( ...@@ -221,7 +221,7 @@ def _linear_bwd_fp8(
tensor_parallel: bool, tensor_parallel: bool,
tp_group: Union[dist_group_type, None], tp_group: Union[dist_group_type, None],
): ):
dgrad, wgrad = None, None dgrad, wgrad, handle = None, None, None
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
if requires_dgrad: if requires_dgrad:
...@@ -239,7 +239,7 @@ def _linear_bwd_fp8( ...@@ -239,7 +239,7 @@ def _linear_bwd_fp8(
use_split_accumulator=_2X_ACC_DGRAD, use_split_accumulator=_2X_ACC_DGRAD,
) )
if parallel_mode == "column" and tensor_parallel: if parallel_mode == "column" and tensor_parallel:
dgrad = allreduce(dgrad, tp_group) dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
if requires_wgrad: if requires_wgrad:
if not fp8_meta["recipe"].override_linear_precision.wgrad: if not fp8_meta["recipe"].override_linear_precision.wgrad:
...@@ -265,6 +265,10 @@ def _linear_bwd_fp8( ...@@ -265,6 +265,10 @@ def _linear_bwd_fp8(
layout="NT", layout="NT",
grad=True, grad=True,
) )
if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait()
return dgrad, wgrad return dgrad, wgrad
...@@ -285,7 +289,7 @@ def _linear_bwd_non_fp8( ...@@ -285,7 +289,7 @@ def _linear_bwd_non_fp8(
""" """
Performs Linear Backward. Optionally, fuses GELU backward and dbias. Performs Linear Backward. Optionally, fuses GELU backward and dbias.
""" """
dgrad, wgrad, bgrad = None, None, None dgrad, wgrad, bgrad, handle = None, None, None, None
if requires_dgrad: if requires_dgrad:
dgrad, _, _ = gemm( dgrad, _, _ = gemm(
weight, weight,
...@@ -298,7 +302,7 @@ def _linear_bwd_non_fp8( ...@@ -298,7 +302,7 @@ def _linear_bwd_non_fp8(
grad=True, grad=True,
) )
if parallel_mode == "column" and tensor_parallel: if parallel_mode == "column" and tensor_parallel:
dgrad = allreduce(dgrad, tp_group) dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)
if requires_wgrad: if requires_wgrad:
wgrad, bgrad, _ = gemm( wgrad, bgrad, _ = gemm(
...@@ -313,6 +317,9 @@ def _linear_bwd_non_fp8( ...@@ -313,6 +317,9 @@ def _linear_bwd_non_fp8(
elif requires_bgrad: elif requires_bgrad:
bgrad = grad_output.sum(axis=0) bgrad = grad_output.sum(axis=0)
if parallel_mode == "column" and tensor_parallel and handle is not None:
handle.wait()
return dgrad, wgrad, bgrad return dgrad, wgrad, bgrad
...@@ -676,7 +683,7 @@ class Linear(TransformerEngineBaseLayer): ...@@ -676,7 +683,7 @@ class Linear(TransformerEngineBaseLayer):
inp = identity(inp, self.tp_group) inp = identity(inp, self.tp_group)
out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None) out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
if self.parallel_mode == 'row' and self.tensor_parallel: if self.parallel_mode == 'row' and self.tensor_parallel:
out = allreduce(out, self.tp_group) out, _ = allreduce(out, self.tp_group)
out = out + self.bias if self.bias is not None else out out = out + self.bias if self.bias is not None else out
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment