Unverified Commit 87bfc348, authored by Min Yu, committed by GitHub
Browse files

[Paddle] Fix forward and backward logic of te.Linear(parallel_mode='column')...


[Paddle] Fix forward and backward logic of te.Linear(parallel_mode='column') to adapt DiT of PaddleMIX (#963)

[Paddle] Fix forward and backward of Linear(parallel_mode='column')

When te.Linear(parallel_mode='column') is not used in a pair with te.Linear(parallel_mode='row'), the output should be all-gathered in the forward pass and reduce-scattered in the backward pass.
Signed-off-by: minyu <minyu@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a435ec01
......@@ -146,6 +146,7 @@ def allgather(
input_: paddle.Tensor,
tp_group: Optional[dist_group_type] = None,
sync_op: bool = True,
axis: int = 0,
) -> Tuple[paddle.Tensor, Any]:
"""All-gather the input tensor across model parallel group."""
......@@ -155,7 +156,7 @@ def allgather(
parallelism = tp_group.nranks
output_shape = input_.shape
output_shape[0] = output_shape[0] * parallelism
output_shape[axis] = output_shape[axis] * parallelism
output = paddle.empty(shape=output_shape, dtype=input_.dtype)
wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op)
if sync_op:
......
......@@ -205,6 +205,7 @@ def _linear_fwd(
tp_group: Union[dist_group_type, None],
is_grad_enabled: bool,
is_first_microbatch: bool = None,
gather_output: bool = False,
):
if fp8_enabled:
out, weight_t_fp8 = _linear_fwd_fp8(
......@@ -241,6 +242,9 @@ def _linear_fwd(
sequence_parallel,
tp_group,
)
if gather_output and tensor_parallel and parallel_mode == "column":
out, _ = allgather(out, tp_group, axis=-1)
return (
out,
weight_t_fp8 if fp8_enabled else None,
......@@ -521,6 +525,7 @@ class _Linear(paddle.autograd.PyLayer):
tp_size: int,
fuse_wgrad_accumulation: bool,
is_first_microbatch: bool,
gather_output: bool,
) -> paddle.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
......@@ -575,6 +580,7 @@ class _Linear(paddle.autograd.PyLayer):
tp_group,
is_grad_enabled,
is_first_microbatch,
gather_output,
)
if is_grad_enabled:
......@@ -606,6 +612,7 @@ class _Linear(paddle.autograd.PyLayer):
ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
ctx.reduce_scatter_output = gather_output
return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
......@@ -668,6 +675,10 @@ class _Linear(paddle.autograd.PyLayer):
# bgrad is fused with gemm for non-FP8 path
bgrad = bgrad_
if ctx.reduce_scatter_output:
wgrad, _ = reduce_scatter(wgrad, ctx.tp_group)
bgrad, _ = reduce_scatter(bgrad, ctx.tp_group)
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
weight_cache_grad = ()
else:
......@@ -742,6 +753,7 @@ class Linear(TransformerEngineBaseLayer):
sequence_parallel: bool = False,
tp_group: Union[dist_group_type, None] = None,
fuse_wgrad_accumulation: bool = False,
gather_output: bool = False,
backend: str = "transformer_engine",
) -> None:
super().__init__()
......@@ -751,6 +763,7 @@ class Linear(TransformerEngineBaseLayer):
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._dtype = self._helper.get_default_dtype()
self.gather_output = gather_output
# Set parallel configs
self.tp_group, self.tp_size = get_tp_group_and_world_size(
......@@ -854,6 +867,7 @@ class Linear(TransformerEngineBaseLayer):
self.tp_size,
self.fuse_wgrad_accumulation,
is_first_microbatch,
self.gather_output,
)
if not self.gemm_bias_fused_add:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment