ring_flash_attn.py 9.91 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention


import torch

from vllm_omni.diffusion.attention.backends.ring.ring_selector import AttnType, select_flash_attn_impl
from vllm_omni.diffusion.attention.backends.ring.ring_utils import update_out_and_lse
from vllm_omni.diffusion.distributed.comm import RingComm


def ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
):
    """Run the Ring Attention forward pass over ``process_group``.

    Each rank keeps its local query shard ``q`` and circulates its K/V shard
    around the ring with ``RingComm``. At every step the rank attends ``q`` to
    the K/V block it currently holds. It then merges the partial result into a
    running ``(out, lse)`` pair with ``update_out_and_lse``. The send/recv for
    the next K/V block is committed before the current block is computed and
    waited on afterwards.

    Args:
        process_group: Process group that forms the ring.
        q, k, v: Local query/key/value shards; made contiguous if needed.
        softmax_scale: Scale applied to the attention logits.
        dropout_p: Dropout probability forwarded to the attention kernel.
        causal: If True, rank ``r`` only consumes the K/V blocks of steps
            ``0..r``, and the kernel's causal mask is applied only on step 0
            (the rank's own block).
        window_size, softcap, alibi_slopes: Forwarded to the kernel unchanged.
        deterministic: Accepted for API compatibility; not used here.
        attn_type: Kernel selector passed to ``select_flash_attn_impl``.
        attn_processor: Optional processor passed to ``select_flash_attn_impl``.
        joint_tensor_key, joint_tensor_value: Optional extra K/V (e.g. text
            conditioning). They are concatenated along dim 1 to the local block
            on step 0 only, so they contribute exactly once. They must be given
            together.
        joint_strategy: ``"front"`` prepends the joint tensors. Any other value
            (conventionally ``"rear"``) appends them.

    Returns:
        ``(out, lse)``, with ``out`` cast to ``q.dtype``. For attention types
        other than ``AttnType.SPARSE_SAGE``, the LSE has its last dim squeezed
        and dims 1 and 2 transposed.

    Raises:
        ValueError: If exactly one of the joint tensors is given. Also raised
            if the joint tensors would be appended (strategy other than
            ``"front"``) while ``causal=True``.
    """
    # Joint K/V must come as a pair. A key without a value would crash inside
    # torch.cat below, and a value without a key would be silently ignored.
    if (joint_tensor_key is None) != (joint_tensor_value is None):
        raise ValueError(
            "joint_tensor_key and joint_tensor_value must be provided together "
            "(both None or both tensors)."
        )

    # Validate the causal + joint_strategy combination. Every strategy other
    # than "front" appends the joint tokens (see the concatenation below), so
    # all such values are rejected, not only the literal "rear". Otherwise a
    # value like "back" would slip through. With appended joint keys, the
    # causal mask would prevent local query tokens from attending to the joint
    # tokens, which should be visible to all local tokens.
    if causal and joint_tensor_key is not None and joint_strategy != "front":
        raise ValueError(
            f"joint_strategy={joint_strategy!r} is not compatible with causal=True in Ring Attention. "
            "When using causal attention with joint tokens, use joint_strategy='front' "
            "to ensure joint tokens act as a visible prefix for all local tokens. "
            "With 'rear' strategy, the causal mask would incorrectly block local tokens "
            "from seeing the joint tokens."
        )

    comm = RingComm(process_group)

    # .contiguous() returns the tensor itself when it is already contiguous.
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    # The kernel choice does not depend on the ring step, so resolve it once
    # (and before any communication is started).
    attn_fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor)

    out = None
    lse = None
    next_k, next_v = None, None

    for step in range(comm.world_size):
        is_last_step = step + 1 == comm.world_size

        if not is_last_step:
            # Kick off the exchange of the next K/V block before computing on
            # the current one.
            next_k = comm.send_recv(k)
            next_v = comm.send_recv(v)
            comm.commit()

        # Under causal masking, blocks that originated on later ranks are
        # skipped entirely.
        if not causal or step <= comm.rank:
            step_k, step_v = k, v
            if step == 0 and joint_tensor_key is not None:
                if joint_strategy == "front":
                    step_k = torch.cat([joint_tensor_key, step_k], dim=1)
                    step_v = torch.cat([joint_tensor_value, step_v], dim=1)
                else:
                    step_k = torch.cat([step_k, joint_tensor_key], dim=1)
                    step_v = torch.cat([step_v, joint_tensor_value], dim=1)

            block_out, block_lse = attn_fn(
                q,
                step_k,
                step_v,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                # Only the diagonal (own) block needs an intra-block causal mask.
                causal=causal and step == 0,
                window_size=window_size,
                softcap=softcap,
                alibi_slopes=alibi_slopes,
                return_softmax=dropout_p > 0,
            )

            if attn_type == AttnType.SPARSE_SAGE:
                # NOTE(review): this path keeps only the latest block's result
                # instead of merging via LSE -- confirm this is intended.
                out, lse = block_out, block_lse
            else:
                out, lse = update_out_and_lse(out, lse, block_out, block_lse)

        if not is_last_step:
            comm.wait()
            k, v = next_k, next_v

    # Step 0 always runs (rank >= 0), so out/lse are set here.
    out = out.to(q.dtype)
    if attn_type != AttnType.SPARSE_SAGE:
        lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse


class RingFlashAttnFunc(torch.autograd.Function):
    """Ring Flash Attention autograd function (inference only, no backward).

    Only ``forward`` is defined; backpropagating through this function is not
    supported.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
        attn_type,
        attn_processor,
        joint_tensor_key=None,
        joint_tensor_value=None,
        joint_strategy="front",
    ):
        """Compute ring attention by delegating to ``ring_flash_attn_forward``.

        ``softmax_scale`` defaults to ``head_dim ** -0.5``, where head_dim is
        ``q.shape[-1]``. All other arguments are forwarded unchanged.

        Returns:
            ``out``, or ``(out, softmax_lse, None)`` when ``return_softmax``
            is true.

        Raises:
            NotImplementedError: If ``alibi_slopes`` is not None.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if alibi_slopes is not None:
            raise NotImplementedError("alibi_slopes is not supported by ring flash attention.")

        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        out, softmax_lse = ring_flash_attn_forward(
            group,
            q.contiguous(),
            k.contiguous(),
            v.contiguous(),
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            attn_type=attn_type,
            attn_processor=attn_processor,
            joint_tensor_key=joint_tensor_key,
            joint_tensor_value=joint_tensor_value,
            joint_strategy=joint_strategy,
        )
        return (out, softmax_lse, None) if return_softmax else out


def ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
):
    """Ring attention over a packed QKV tensor.

    Slices 0, 1 and 2 along dim 2 of ``qkv`` are used as query, key and value.
    They are passed to ``RingFlashAttnFunc`` with no attention processor and no
    joint tensors. The layout is presumably (batch, seq_len, 3, num_heads,
    head_dim); TODO confirm against callers.

    Returns:
        The attention output, or ``(out, softmax_lse, None)`` when
        ``return_attn_probs`` is true.
    """
    query, key, value = (qkv[:, :, slot] for slot in range(3))
    no_processor = no_joint_key = no_joint_value = None
    return RingFlashAttnFunc.apply(
        query,
        key,
        value,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        no_processor,
        no_joint_key,
        no_joint_value,
        "front",
    )


def ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
):
    """Ring attention with a separate query and a packed KV tensor.

    Slices 0 and 1 along dim 2 of ``kv`` are used as key and value. They are
    passed to ``RingFlashAttnFunc`` with no attention processor and no joint
    tensors. The ``kv`` layout is presumably (batch, seq_len, 2, num_heads,
    head_dim); TODO confirm against callers.

    Returns:
        The attention output, or ``(out, softmax_lse, None)`` when
        ``return_attn_probs`` is true.
    """
    key, value = (kv[:, :, slot] for slot in (0, 1))
    apply_args = (
        q,
        key,
        value,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        None,  # attn_processor
        None,  # joint_tensor_key
        None,  # joint_tensor_value
        "front",  # joint_strategy
    )
    return RingFlashAttnFunc.apply(*apply_args)


def ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, None]:
    """Ring Attention forward pass using Flash Attention backend.

    Implements Ring Attention with sequence parallelism using a ring-based P2P
    communication pattern. The sequence dimension is sharded across devices, and
    Key/Value blocks are circulated through the ring to accumulate attention results.

    Args:
        q (torch.Tensor): Query tensor of shape (batch, seq_len, num_heads, head_dim).
            Sequence dimension is sharded across the ring group.
        k (torch.Tensor): Key tensor of shape (batch, seq_len, num_heads, head_dim).
            Sequence dimension is sharded across the ring group.
        v (torch.Tensor): Value tensor of shape (batch, seq_len, num_heads, head_dim).
            Sequence dimension is sharded across the ring group.
        dropout_p (float): Dropout probability. Defaults to 0.0.
        softmax_scale (float | None): Scaling factor for softmax.
            If None, computed as head_dim^(-0.5).
        causal (bool): Whether to apply causal masking. Defaults to False.
        window_size (tuple[int, int]): Sliding window size for attention.
            (-1, -1) means no windowing.
        softcap (float): Soft capping value for attention logits. Defaults to 0.0.
        alibi_slopes (torch.Tensor | None): ALiBi slopes for positional bias.
            Not supported; must be None.
        deterministic (bool): Accepted for API compatibility; the ring forward
            pass does not use it. Defaults to False.
        return_attn_probs (bool): If True, returns (out, softmax_lse, None).
            Defaults to False.
        group (ProcessGroup | None): Process group for ring communication.
            Defaults to None.
        attn_type (AttnType): Flash Attention implementation type
            (AttnType.FA, AttnType.FA3, etc.).
        attn_processor (Callable | None): Custom attention processor for sparse
            attention. Defaults to None.
        joint_tensor_key (torch.Tensor | None): Additional key tensor for joint
            attention (e.g., text + image). Concatenated (along dim 1) to the
            local keys only at step=0, so it contributes once. Defaults to None.
        joint_tensor_value (torch.Tensor | None): Additional value tensor for
            joint attention (e.g., text + image). Concatenated (along dim 1) to
            the local values only at step=0. Defaults to None.
        joint_strategy (str): Concatenation strategy. "front" prepends the joint
            tensors; "rear" (or any other value) appends them. With
            causal=True only "front" is valid. Defaults to "front".

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, None]]:
            - If return_attn_probs is False: Output tensor (batch, seq_len, num_heads, head_dim).
            - If return_attn_probs is True: A tuple (out, softmax_lse, None).
    """
    return RingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
        joint_tensor_key,
        joint_tensor_value,
        joint_strategy,
    )