[JAX] Context Parallel Attention with All-Gather (#1106)
Implementation of context parallel fused attention using all-gather.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
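
For orientation, the general idea of context parallel attention with all-gather can be sketched as follows. This is a minimal illustration and not the code from this commit: it assumes non-causal, single-head attention, uses plain (unfused) softmax attention, and uses `jax.vmap` with a named axis as a stand-in for a real device mesh axis. The names `local_attention`, `context_parallel_attention`, `reference_attention`, and the axis name `"cp"` are hypothetical.

```python
import jax
import jax.numpy as jnp


def local_attention(q_blk, k_blk, v_blk):
    """Per-rank computation; each input is this rank's [chunk_len, head_dim] slice."""
    head_dim = q_blk.shape[-1]
    # All-gather K and V across the context-parallel axis: [num_ranks, chunk_len, head_dim].
    k_all = jax.lax.all_gather(k_blk, "cp")
    v_all = jax.lax.all_gather(v_blk, "cp")
    # Flatten the rank axis so each rank holds K/V for the full sequence.
    k_full = k_all.reshape(-1, head_dim)
    v_full = v_all.reshape(-1, head_dim)
    # Local queries attend over the full sequence (no causal mask in this sketch).
    scores = (q_blk @ k_full.T) * head_dim ** -0.5
    return jax.nn.softmax(scores, axis=-1) @ v_full


def context_parallel_attention(q, k, v, num_ranks):
    """Split the sequence into num_ranks chunks and run local_attention on each."""
    seq_len, head_dim = q.shape

    def split(x):
        return x.reshape(num_ranks, seq_len // num_ranks, head_dim)

    # Assumption: vmap with axis_name emulates the collective on a single device.
    out = jax.vmap(local_attention, axis_name="cp")(split(q), split(k), split(v))
    return out.reshape(seq_len, head_dim)


def reference_attention(q, k, v):
    """Unsharded attention, used only to check the sketch."""
    scores = (q @ k.T) * q.shape[-1] ** -0.5
    return jax.nn.softmax(scores, axis=-1) @ v
```

Because every rank gathers the full K and V, each rank's output rows depend only on its own queries, so no further communication is needed in the forward pass of this sketch. On real hardware the named axis would typically come from a device mesh rather than `jax.vmap`.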