Unverified commit 3fc1e4bf, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Fix for TE GEMM - Always AllGather RHS non-contracting dims with FSDP axis (#2075)



* fix fsdp
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 0e3e270f
...@@ -476,21 +476,24 @@ class GemmPrimitive(BasePrimitive): ...@@ -476,21 +476,24 @@ class GemmPrimitive(BasePrimitive):
lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)
# Non-batched non-contracting dims of RHS needs to be unsharded (i.e. FSDP) # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden
# Check if spec is not the batch-dim is not needed as rhs_non_cspecs never includes batch-dim # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim.
# rhs_specs only includes batch-dim in the Wgrad GEMM, but there batch-dim belongs to rhs_cspecs # In `rhs_specs`, the batch dim appears only in Wgrad GEMM under `rhs_cspecs`.
rhs_non_cspecs = tuple( rhs_non_cspecs = tuple(
None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
) )
else: else:
# Otherwise, require contracting dims of both operands to be unsharded # Otherwise, require contracting dims of both operands to be unsharded
lhs_cspecs = (None,) * len(lhs_cspecs) lhs_cspecs = (None,) * len(lhs_cspecs)
rhs_cspecs = (None,) * len(rhs_cspecs) rhs_cspecs = (None,) * len(rhs_cspecs)
# Non-batched non-contracting dims of LHS to be unsharded, i.e gather SP dim # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
# The spec for batch_dim in lhs_non_cspecs won't ever appear in the rhs_non_cspecs as rhs_non_cspecs = tuple(
# rhs_non_cspecs never has batch-dim. Hence, spec for batch_dim of lhs_non_cspecs won't be None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
# overwrite )
# Non-contracting dims of LHS to be gathered along the SP axis.
# Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for
# dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has been found yet. # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has been found yet.
lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment