[JAX] Fix `rng_state` shape in fused attention (#2217)
fix rng_state shape Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Showing
Please register or sign in to comment