You need to sign in or sign up before continuing.
[JAX] Correct fused attention output after each step of ring attention (#1393)
Correct fused attention output after each step to reduce intermediate memory use.
Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com>
Showing
Please register or sign in to comment