Unverified commit 87bfc348, authored by Min Yu, committed by GitHub

[Paddle] Fix forward and backward logic of te.Linear(parallel_mode='column') to adapt DiT of PaddleMIX (#963)

[Paddle] Fix forward and backward of Linear(parallel_mode='column')

When te.Linear(parallel_mode='column') is not paired with te.Linear(parallel_mode='row'), its output should be all-gathered across the tensor-parallel group in the forward pass, and the weight and bias gradients reduce-scattered in the backward pass.
Signed-off-by: minyu <minyu@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a435ec01
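
As a rough illustration of how the new flag is meant to be used, here is a minimal, hypothetical sketch. It assumes the tensor-parallel process group has already been set up (e.g. via paddle.distributed fleet) and that the script is launched with one process per TP rank; aside from parallel_mode and gather_output, the argument names follow the usual te.Linear signature and are illustrative only.

import paddle
import transformer_engine.paddle as te

# Hypothetical standalone column-parallel projection (e.g. inside a DiT block)
# whose full-width output is consumed by ordinary, non-parallel layers instead
# of a paired te.Linear(parallel_mode='row').
proj = te.Linear(
    1024,                    # in_features
    4096,                    # out_features, split across TP ranks
    parallel_mode="column",
    gather_output=True,      # new flag: all-gather output shards in forward,
                             # reduce-scatter weight/bias grads in backward
)

x = paddle.randn([8, 1024])
y = proj(x)                  # y has the full width of 4096 on every rank
y.sum().backward()           # grads for the local weight shard per rank
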
@@ -146,6 +146,7 @@ def allgather(
     input_: paddle.Tensor,
     tp_group: Optional[dist_group_type] = None,
     sync_op: bool = True,
+    axis: int = 0,
 ) -> Tuple[paddle.Tensor, Any]:
     """All-gather the input tensor across model parallel group."""
@@ -155,7 +156,7 @@ def allgather(
     parallelism = tp_group.nranks
     output_shape = input_.shape
-    output_shape[0] = output_shape[0] * parallelism
+    output_shape[axis] = output_shape[axis] * parallelism
     output = paddle.empty(shape=output_shape, dtype=input_.dtype)
     wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op)
     if sync_op:
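
The new axis argument only changes which dimension is multiplied by the group size. A single-process sketch of the intended shape semantics follows (the gather is emulated with paddle.concat here; the real helper calls all_gather_into_tensor on the TP group):

import paddle

tp_size = 2
# Per-rank shards as produced by a column-parallel GEMM: [batch, hidden // tp_size]
shards = [paddle.randn([8, 2048]) for _ in range(tp_size)]

# axis=0 (the previous, hard-coded behaviour) grows the batch dimension: [16, 2048]
gathered_batch = paddle.concat(shards, axis=0)

# axis=-1 (used by gather_output) restores the full hidden dimension: [8, 4096]
gathered_hidden = paddle.concat(shards, axis=-1)

print(gathered_batch.shape, gathered_hidden.shape)  # [16, 2048] [8, 4096]
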
@@ -205,6 +205,7 @@ def _linear_fwd(
     tp_group: Union[dist_group_type, None],
     is_grad_enabled: bool,
     is_first_microbatch: bool = None,
+    gather_output: bool = False,
 ):
     if fp8_enabled:
         out, weight_t_fp8 = _linear_fwd_fp8(
@@ -241,6 +242,9 @@ def _linear_fwd(
             sequence_parallel,
             tp_group,
         )
+    if gather_output and tensor_parallel and parallel_mode == "column":
+        out, _ = allgather(out, tp_group, axis=-1)
     return (
         out,
         weight_t_fp8 if fp8_enabled else None,
@@ -521,6 +525,7 @@ class _Linear(paddle.autograd.PyLayer):
         tp_size: int,
         fuse_wgrad_accumulation: bool,
         is_first_microbatch: bool,
+        gather_output: bool,
     ) -> paddle.Tensor:
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
@@ -575,6 +580,7 @@ class _Linear(paddle.autograd.PyLayer):
             tp_group,
             is_grad_enabled,
             is_first_microbatch,
+            gather_output,
         )
         if is_grad_enabled:
@@ -606,6 +612,7 @@ class _Linear(paddle.autograd.PyLayer):
             ctx.requires_wgrad = not weight.stop_gradient
             ctx.requires_bgrad = use_bias and not bias.stop_gradient
             ctx.is_first_microbatch = is_first_microbatch
+            ctx.reduce_scatter_output = gather_output
         return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
@@ -668,6 +675,10 @@ class _Linear(paddle.autograd.PyLayer):
                 # bgrad is fused with gemm for non-FP8 path
                 bgrad = bgrad_
+        if ctx.reduce_scatter_output:
+            wgrad, _ = reduce_scatter(wgrad, ctx.tp_group)
+            bgrad, _ = reduce_scatter(bgrad, ctx.tp_group)
         if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
             weight_cache_grad = ()
         else:
@@ -742,6 +753,7 @@ class Linear(TransformerEngineBaseLayer):
         sequence_parallel: bool = False,
         tp_group: Union[dist_group_type, None] = None,
         fuse_wgrad_accumulation: bool = False,
+        gather_output: bool = False,
         backend: str = "transformer_engine",
     ) -> None:
         super().__init__()
@@ -751,6 +763,7 @@ class Linear(TransformerEngineBaseLayer):
         self._weight_attr = weight_attr
         self._bias_attr = bias_attr
         self._dtype = self._helper.get_default_dtype()
+        self.gather_output = gather_output
         # Set parallel configs
         self.tp_group, self.tp_size = get_tp_group_and_world_size(
@@ -854,6 +867,7 @@ class Linear(TransformerEngineBaseLayer):
             self.tp_size,
             self.fuse_wgrad_accumulation,
             is_first_microbatch,
+            self.gather_output,
         )
         if not self.gemm_bias_fused_add:
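
Putting the two halves together, the commit message's "all-gather in forward, reduce-scatter in backward" can be summarised with a hedged, single-process shape walkthrough. Plain matmuls and concat/split stand in for the fused GEMMs and the TP collectives; this sketch shows shapes only, not the exact numerics of the FP8 or fused paths.

import paddle

tp_size, batch, h_in, h_out = 2, 8, 64, 128

x = paddle.randn([batch, h_in])                       # input, replicated on each rank
w_shards = [paddle.randn([h_out // tp_size, h_in])    # column-parallel weight shards
            for _ in range(tp_size)]

# Forward: every rank computes its slice of the output ...
local_outs = [paddle.matmul(x, w, transpose_y=True)   # [batch, h_out // tp_size] each
              for w in w_shards]
# ... and gather_output=True all-gathers the slices along the last axis.
full_out = paddle.concat(local_outs, axis=-1)         # [batch, h_out] on every rank

# Backward: the incoming gradient covers the full output width,
grad_out = paddle.randn([batch, h_out])
# so the per-rank wgrad GEMM produces a full-size gradient ...
full_wgrads = [paddle.matmul(grad_out, x, transpose_x=True)  # [h_out, h_in] each
               for _ in range(tp_size)]
# ... which reduce_scatter sums across ranks and splits back into per-rank shards,
# matching each rank's local weight shard.
summed = paddle.add_n(full_wgrads)
wgrad_shards = paddle.split(summed, tp_size, axis=0)  # [h_out // tp_size, h_in] each

print(full_out.shape, wgrad_shards[0].shape)          # [8, 128] [64, 64]
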