Unverified commit a282136c, authored by Phuong Nguyen and committed by GitHub
Browse files

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)



* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>


---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c9508000
...@@ -289,6 +289,13 @@ def _layernorm_mlp_fwd_rule( ...@@ -289,6 +289,13 @@ def _layernorm_mlp_fwd_rule(
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
# This sharding constraint is needed to correct the Shardy sharding propagation
if dot_2_input_axes is not None:
dot_1_output_axes = (
dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
) # add the act_num axis
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
# (batch..., hidden_in) -> (batch..., hidden) # (batch..., hidden_in) -> (batch..., hidden)
......
...@@ -165,7 +165,7 @@ def with_sharding_constraint_by_logical_axes( ...@@ -165,7 +165,7 @@ def with_sharding_constraint_by_logical_axes(
flax_rules = flax.linen.get_logical_axis_rules() flax_rules = flax.linen.get_logical_axis_rules()
if len(flax_rules) > 0: if len(flax_rules) > 0:
return flax.linen.with_logical_constraint( return flax.linen.with_logical_constraint(
x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.AXIS_IS_UNSHARDED
) )
except ImportError: except ImportError:
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment