Unverified Commit b0d562d8 authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Fix `rng_state` shape in fused attention (#2217)



fix rng_state shape
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
parent ac4e0fd6
...@@ -1820,7 +1820,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1820,7 +1820,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
# RNG shape should be the shared shape. This is unused for ring attention as we do not
# support dropout currently.
-            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
+            rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:])
             rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)
             def scan_kv_block(idx, carry):
...@@ -2306,7 +2306,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2306,7 +2306,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
# RNG shape should be the shared shape. This is unused for ring attention as we do not
# support dropout currently.
-            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
+            rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:])
             rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)
             def scan_kv_block(idx, carry):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment