Unverified commit a4cb1d17, authored by Michael Goldfarb, committed by GitHub

[JAX] Correct fused attention output after each step of ring attention (#1393)



Correct the fused attention output after each step of ring attention, rather than stacking per-step outputs and reducing at the end, to cut intermediate memory use.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent 61cf1020
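What "correct after each step" means, in equations: each ring step produces an output that is softmax-normalized only over its own KV block, together with the per-row log-sum-exp statistic. Two such partial results can be merged exactly. Writing a, b for the accumulated and per-step log-sum-exp and o, o_b for the corresponding outputs (my notation, with σ the logistic sigmoid):

```
log(e^a + e^b)                  =  a - log σ(a - b)           # softmax_aux update
(e^a·o + e^b·o_b) / (e^a + e^b) =  o - σ(b - a)·(o - o_b)     # output update
```

The right-hand sides are the numerically stable forms used in the diff below; applying them inside the loop removes the need to keep every step's output and stats around for a final reduction.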
@@ -401,7 +401,7 @@ class TestDistributedContextParallelSelfAttn:
             raise ValueError(f"Unsupported {qkv_layout=}")
         return qkv_args
 
-    def impl_test_contex_parallel_attn(
+    def impl_test_context_parallel_attn(
         self,
         device_count,
         mesh_shape,
@@ -583,7 +583,7 @@ class TestDistributedContextParallelSelfAttn:
             assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
 
-    def test_contex_parallel_allgather_attn(
+    def test_context_parallel_allgather_attn(
         self,
         device_count,
         mesh_shape,
@@ -596,7 +596,7 @@ class TestDistributedContextParallelSelfAttn:
         qkv_layout,
         load_balanced,
     ):
-        return self.impl_test_contex_parallel_attn(
+        return self.impl_test_context_parallel_attn(
             device_count,
             mesh_shape,
             mesh_axes,
@@ -623,7 +623,7 @@ class TestDistributedContextParallelSelfAttn:
         qkv_layout,
         load_balanced,
     ):
-        return self.impl_test_contex_parallel_attn(
+        return self.impl_test_context_parallel_attn(
         device_count,
         mesh_shape,
         mesh_axes,
@@ -1549,12 +1549,19 @@ class _FusedAttnCPWithP2PHelper:
         """Permutes kv around the ring as described by cp_perm."""
         return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)
 
-    def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step):
-        """Apply softmax correction after an attention step."""
-        max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step)
-        min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step)
-        new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale))
-        return new_softmax_aux
+    @staticmethod
+    def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
+        """
+        Corrects the output and softmax_aux tensors after each iteration of ring attention.
+        See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
+        the derivation of this equation.
+        """
+        new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
+            0, 2, 1, 3
+        ) * (output - partial_output)
+        new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
+        return new_out, new_aux
 
     def adjust_seqlen(self, seqlen, max_seqlen, idx):
         """Adjust the sequence length per step."""
@@ -1615,10 +1622,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
         cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
         cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
 
-        output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=q.dtype)
-        softmax_aux_per_steps = jnp.zeros(
-            (cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32
-        )
+        output = jnp.zeros(q.shape).astype(jnp.float32)
         softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)
 
         # RNG shape should be the shared shape. This is unused for ring attention as we do not
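The removed per-step stacks are where the memory saving comes from: they scaled with cp_size, while the new running accumulator does not. A back-of-the-envelope comparison with assumed sizes (hypothetical, not taken from the PR; output_per_steps assumed bf16 via q.dtype):

```python
cp_size, batch, seqlen, heads, dim = 8, 2, 4096, 16, 128
elems = batch * seqlen * heads * dim                 # one output tensor

old = cp_size * elems * 2                            # bf16 output_per_steps
old += cp_size * batch * heads * seqlen * 4          # f32 softmax_aux_per_steps
new = elems * 4                                      # single f32 output accumulator
new += batch * heads * seqlen * 4                    # single f32 softmax_aux

print(f"old ~ {old / 2**20:.0f} MiB, new ~ {new / 2**20:.0f} MiB")
# old ~ 260 MiB, new ~ 64 MiB for these sizes
```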
@@ -1627,7 +1631,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
         rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)
 
         def scan_kv_block(idx, carry):
-            kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry
+            kv, output, softmax_aux = carry
 
             # Send KV block to next step so we can overlap compute.
             kv_next = helper.permute_kv(kv, cp_perm)
@@ -1718,25 +1722,38 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
             else:
                 output_per_step, softmax_aux_per_step = no_mask_compute()
 
-            softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step)
-            output_per_steps = output_per_steps.at[idx].set(output_per_step)
-            softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step)
+            def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
+                # No correction is done here, but we cast the outputs to float32 and
+                # perform the reduction in full precision.
+                # pylint: disable=unused-argument
+                return output_per_step.astype(jnp.float32), softmax_aux_per_step
+
+            def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
+                return helper.correct_output_and_softmax_aux(
+                    output, softmax_aux, output_per_step, softmax_aux_per_step
+                )
 
-            return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps)
+            # On the first step there is no correction; we take the initial output and stats.
+            output, softmax_aux = lax.cond(
+                (idx == 0),
+                skip_correction,
+                correction,
+                output,
+                softmax_aux,
+                output_per_step,
+                softmax_aux_per_step,
+            )
+            return (kv_next, output, softmax_aux)
 
-        carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps)
+        carry = (kv, output, softmax_aux)
         if helper.use_scanloop():
            carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
         else:
            for i in range(0, cp_size):
                carry = scan_kv_block(i, carry)
-        (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry
+        (kv, output, softmax_aux) = carry
 
-        output = jnp.zeros(q.shape).astype(jnp.float32)
-        for idx in range(cp_size):
-            output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp(
-                softmax_aux_per_steps[idx] - softmax_aux
-            ).transpose(0, 2, 1, 3)
         output = output.astype(q.dtype)
         return output, softmax_aux, rng_state
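Putting the pieces together, here is a condensed sketch of the new control flow (the helper `ring_accumulate` and its shapes are hypothetical; the real code computes each step's output inside the loop while rotating KV, rather than indexing precomputed per-step results):

```python
import jax
import jax.numpy as jnp
from jax import lax

def ring_accumulate(per_step_out, per_step_aux):
    """per_step_out: [cp, b, s, h, d] f32; per_step_aux: [cp, b, h, s, 1] f32."""
    cp_size = per_step_out.shape[0]
    init = (
        jnp.zeros(per_step_out.shape[1:], jnp.float32),
        jnp.full(per_step_aux.shape[1:], -jnp.inf, jnp.float32),
    )

    def step(idx, carry):
        out, aux = carry
        out_i, aux_i = per_step_out[idx], per_step_aux[idx]

        def first(out, aux):
            # Step 0 seeds the carry: with aux initialized to -inf, the correction's
            # aux update would evaluate -inf - (-inf) = NaN, hence the lax.cond.
            return out_i, aux_i

        def correct(out, aux):
            # Later steps: sigmoid form of the log-sum-exp merge; stats broadcast
            # from [b, h, s, 1] to the [b, s, h, d] output layout via the transpose.
            new_out = out - jax.nn.sigmoid(aux_i - aux).transpose(0, 2, 1, 3) * (out - out_i)
            new_aux = aux - jax.nn.log_sigmoid(aux - aux_i)
            return new_out, new_aux

        return lax.cond(idx == 0, first, correct, out, aux)

    return lax.fori_loop(0, cp_size, step, init)
```

The fori_loop carry holds one output-sized tensor regardless of cp_size, which is the memory win over the previous stack-then-reduce approach.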