[Paddle] Fix forward and backward logic of te.Linear(parallel_mode='column')...
[Paddle] Fix forward and backward logic of te.Linear(parallel_mode='column') to adapt DiT of PaddleMIX (#963) [Paddle] Fix forward and backward of Linear(parallel_mode='column') When te.Linear(parallel_mode='column') is not used in pairs with te.Linear(parallel_mode='row'), the output should to be all-gathered when forward and reduce-scattered when backward. Signed-off-by:minyu <minyu@nvidia.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Showing
Please register or sign in to comment