Unverified commit f5d720a0, authored by Alp Dener, committed by GitHub

Incorrect use of extend_fsdp_sharding_meta() in cross_fused_attn() (#482)



Fixed incorrect use of extend_fsdp_sharding_meta() in cross_fused_attn().
Signed-off-by: Alp Dener <adener@nvidia.com>
parent 2a86df2b
@@ -206,7 +206,7 @@ def cross_fused_attn(q: jnp.ndarray,
         tp_dims=([2, 3, None, None], [2]),
         dp_axis_name=dp_axis_name,
         tp_axis_name=tp_axis_name)
-    sharding_meta = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
+    sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
     inputs_ = tuple(
         jnp.reshape(x, new_shape) if x is not None else None
...
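
The one-line change suggests that extend_fsdp_sharding_meta() returns a pair, and that the old call site bound the whole pair to sharding_meta instead of unpacking it. The sketch below illustrates that failure mode with stand-ins only: ShardingMetaStub, extend_meta_stub, and the "fsdp_axis" value are hypothetical names invented for this example, not the library's actual types or return values.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class ShardingMetaStub:
    # Hypothetical stand-in for the library's sharding metadata object.
    axis_resources: Dict[str, str]


def extend_meta_stub(meta: ShardingMetaStub,
                     dim_map: Dict[int, int]) -> Tuple[ShardingMetaStub, str]:
    # Hypothetical stand-in: like the patched call site implies for the real
    # function, it returns a pair (updated metadata, some extra value).
    # dim_map is accepted only to mirror the call shape; it is unused here.
    updated = ShardingMetaStub({**meta.axis_resources, "fsdp": "fsdp_axis"})
    return updated, "fsdp_axis"


meta = ShardingMetaStub({"dp": "data"})

# Before the fix: the whole pair is bound to a single name, so later code
# that expects a metadata object receives a tuple instead.
wrong = extend_meta_stub(meta, {0: 0, 2: 0})

# After the fix: unpack the pair and discard the second element.
right, _ = extend_meta_stub(meta, {0: 0, 2: 0})

print(hasattr(wrong, "axis_resources"))  # False: wrong is a plain tuple
print(hasattr(right, "axis_resources"))  # True: right is the metadata object
```

Under these assumptions, the unpatched form fails only later, when downstream code reads an attribute from the tuple, which is why the mistake can be easy to miss at the call site.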