ring_attn.py 14.2 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger

from lightx2v.utils.envs import *
from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate
from .utils.ring_comm import RingComm

# flash-attn v2 is an optional dependency: keep module import working without it
# and let the failure surface only when the attention path is actually used.
try:
    import flash_attn
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None


@torch.jit.script
def _update_out_and_lse(
    out,
    lse,
    block_out,
    block_lse,
):
    """Fold one ring step's partial attention output into the accumulator.

    `out`/`lse` are the running (float32) output and log-sum-exp; `block_out`/
    `block_lse` come from the current flash-attn block. Returns the merged pair.
    """
    incoming = block_out.to(torch.float32)
    # flash-attn returns lse as [b, heads, seq]; bring it to [b, seq, heads, 1]
    # so it broadcasts against `out`.
    incoming_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # Numerically stable form of:
    #   new_lse = lse + log(1 + exp(incoming_lse - lse))
    #   new_out = exp(lse - new_lse) * out + exp(incoming_lse - new_lse) * incoming
    # For additional context and discussion, please refer to:
    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    merged_out = out - torch.sigmoid(incoming_lse - lse) * (out - incoming)
    merged_lse = lse - F.logsigmoid(lse - incoming_lse)
    return merged_out, merged_lse


@ATTN_WEIGHT_REGISTER("ring")
class RingAttnWeight(AttnWeightTemplate):
    """Ring attention: each rank holds a sequence shard and image K/V blocks are
    rotated around the sequence-parallel ring, with partial outputs merged via
    the log-sum-exp trick. K/V may be fused into one tensor and/or quantized to
    FP8 for communication.
    """

    def __init__(self):
        # `config` is kept for template compatibility; helper bundles the
        # quantization / ring-communication utilities.
        self.config = {}
        self.helper = RingAttnHelper()

    def apply(
        self,
        q,
        k,
        v,
        slice_qkv_len,
        cu_seqlens_qkv,
        attention_module=None,
        attention_type="flash_attn2",
        seq_p_group=None,
        use_fp8_comm=False,
        use_tensor_fusion=False,
        enable_head_parallel=False,
        **kwargs,
    ):
        """
        Run ring attention over combined image and text queries, keys, values.

        Args:
            q (torch.Tensor): query tensor, shape [shard_seqlen, heads, hidden_dims]
            k (torch.Tensor): key tensor, shape [shard_seqlen, heads, hidden_dims]
            v (torch.Tensor): value tensor, shape [shard_seqlen, heads, hidden_dims]
            slice_qkv_len (int): length of the image part of q/k/v
            cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths carrying the
                text and image length information (len 2 or 3; see _get_text_lengths)
            attention_type (str): attention backend, defaults to "flash_attn2"

        Returns:
            torch.Tensor: the computed attention result, [total_seq, heads * hidden_dims]
        """
        assert not enable_head_parallel, "RingAttn can't support head parallel mode."

        use_kv_fusion = use_tensor_fusion
        # Rank and world size within the sequence-parallel group.
        # NOTE(review): cur_rank is currently unused in this method.
        cur_rank = dist.get_rank(seq_p_group)
        world_size = dist.get_world_size(seq_p_group)

        img_qkv_len = slice_qkv_len
        txt_qkv_len, txt_mask_len = self.helper._get_text_lengths(cu_seqlens_qkv, img_qkv_len)

        # if RING_COMM is None:
        #     init_ring_comm()

        RING_COMM = RingComm(seq_p_group)

        # if len(cu_seqlens_qkv) == 3:
        #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len
        #     txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len
        # elif len(cu_seqlens_qkv) == 2:
        #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len
        #     txt_mask_len = None
        # Add a batch dimension: [1, seq, heads, hidden_dims].
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)

        heads, hidden_dims = k.shape[-2], k.shape[-1]
        # Split the shard into its image prefix and text segment.
        img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
        txt_q, txt_k, txt_v = (
            q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
            k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
            v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
        )

        out, lse, next_k, next_v = None, None, None, None

        if len(cu_seqlens_qkv) == 3:
            q = torch.cat((img_q, txt_q), dim=1)
        # Only image K/V circulates around the ring; text K/V is appended
        # locally on the final step below.
        k = img_k
        v = img_v

        if use_kv_fusion:
            # Stack K and V into one [2, seq, heads, dims] tensor so a single
            # send/recv moves both.
            txt_kv = torch.stack([txt_k, txt_v], dim=0).reshape(2, txt_qkv_len, heads, hidden_dims).contiguous()
            kv, original_dtype, original_shape = self.helper._prepare_kv_tensors(k, v, use_kv_fusion)
        else:
            original_dtype = k.dtype
            original_shape = k.shape

        for step in range(world_size):
            # Post the async send/recv for the *next* K/V block first so that
            # communication overlaps with this step's attention compute.
            if step + 1 != world_size:
                if use_fp8_comm:
                    if use_kv_fusion:
                        next_kv_fp8, next_kv_scale = self.helper._send_recv_tensor(kv, hidden_dims, RING_COMM, use_fp8_comm, original_shape)
                    else:
                        next_k_fp8, next_k_scale = self.helper._send_recv_tensor(k, hidden_dims, RING_COMM, use_fp8_comm, original_shape)
                        next_v_fp8, next_v_scale = self.helper._send_recv_tensor(v, hidden_dims, RING_COMM, use_fp8_comm, original_shape)
                else:
                    if use_kv_fusion:
                        next_kv = self.helper._send_recv_tensor(kv, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0]
                    else:
                        next_k = self.helper._send_recv_tensor(k, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0]
                        next_v = self.helper._send_recv_tensor(v, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0]
                RING_COMM.commit()

            # On the last ring step, append the local text K/V so the query
            # also attends over the text tokens exactly once.
            if step + 1 == world_size:
                if use_kv_fusion:
                    kv = torch.cat((kv, txt_kv), dim=1)
                else:
                    k = torch.cat((k, txt_k), dim=1)
                    v = torch.cat((v, txt_v), dim=1)

            if use_kv_fusion:
                block_out, block_lse = self.ring_attn_sub_kv_fusion(q, kv)
            else:
                block_out, block_lse = self.ring_attn_sub(q, k, v)

            # Merge this block's partial result into the running accumulator.
            out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)

            if step + 1 != world_size:
                # Wait for the posted communication, then swap in the received
                # block (dequantizing first when FP8 comm is enabled).
                RING_COMM.wait()

                if use_fp8_comm:
                    if use_kv_fusion:
                        kv = self.helper._dequantize_received(next_kv_fp8, next_kv_scale, original_dtype, original_shape, use_kv_fusion=True, is_kv_fusion=True)
                    else:
                        k, v = self.helper._dequantize_received(
                            next_k_fp8, next_k_scale, original_dtype, original_shape, use_kv_fusion=False, is_kv_fusion=False, v_fp8=next_v_fp8, v_scale=next_v_scale
                        )
                else:
                    if use_kv_fusion:
                        kv = next_kv
                    else:
                        k, v = next_k, next_v

        # Cast back to the working dtype and flatten heads: [seq, heads * dims].
        attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)

        # NOTE(review): when a text mask is present, the trailing
        # (txt_mask_len - txt_qkv_len) tokens get a separate, non-ring flash
        # attention over the local tail only — presumably padded mask tokens
        # that must not mix across ranks; confirm against the caller.
        if txt_mask_len > 0:
            attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
                q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                dropout_p=0.0,
                softmax_scale=q.shape[-1] ** (-0.5),
                causal=False,
                window_size_left=-1,
                window_size_right=-1,
                softcap=0.0,
                alibi_slopes=None,
                return_softmax=False,
            )

            attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
            attn1 = torch.cat([attn1, attn2], dim=0)

        return attn1

    def ring_attn_sub_kv_fusion(self, q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
        """One flash-attn block with fused K/V: kv[0] is K, kv[1] is V.

        Returns (block_out, block_lse) for the log-sum-exp merge.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
            q,
            kv[:1, :, :, :],
            kv[1:, :, :, :],
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )
        return block_out, block_lse

    def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
        """One flash-attn block with separate K and V tensors.

        Returns (block_out, block_lse) for the log-sum-exp merge.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
            q,
            k,
            v,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )
        return block_out, block_lse

    def update_out_and_lse(
        self,
        out,
        lse,
        block_out,
        block_lse,
        slice_=None,
    ):
        """Accumulate a block's (out, lse) into the running pair.

        The first call (out is None) just initializes the accumulator;
        later calls delegate to the scripted `_update_out_and_lse`. `slice_`
        restricts the update to a sub-range of the accumulator.
        """
        if out is None:
            if slice_ is not None:
                raise RuntimeError("first update_out_and_lse should not pass slice_ args")
            out = block_out.to(torch.float32)
            # Match _update_out_and_lse's layout: [b, seq, heads, 1].
            lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
        elif slice_ is not None:
            slice_out, slice_lse = out[slice_], lse[slice_]
            slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
            out[slice_], lse[slice_] = slice_out, slice_lse
        else:
            out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
        return out, lse


class RingAttnHelper:
    """Utility methods for ring attention: FP8 quantization, ring send/recv,
    dequantization of received blocks, and sequence-length bookkeeping."""

    @staticmethod
    def _quant_and_send(tensor, hidden_dims, comm, original_shape=None):
        """Quantize ``tensor`` to FP8 and exchange it through the ring.

        Args:
            tensor: tensor to quantize and send
            hidden_dims: size of the trailing hidden dimension
            comm: ring communicator
            original_shape: shape to restore after flat quantization
                (defaults to ``tensor.shape``)

        Returns:
            tuple: (received FP8 tensor, received scale tensor)
        """
        shape = tensor.shape if original_shape is None else original_shape

        # Row-wise FP8 quantization over the hidden dimension.
        fp8_flat, scale_flat = quant_fp8_vllm(tensor.reshape(-1, hidden_dims))

        # Restore the original layout; scales get a trailing singleton dim.
        fp8 = fp8_flat.reshape(shape)
        scale = scale_flat.reshape(shape[0], shape[1], shape[2], 1)

        # Payload and scales travel around the ring separately.
        return comm.send_recv(fp8), comm.send_recv(scale)

    @staticmethod
    def _prepare_kv_tensors(k, v, use_kv_fusion):
        """Return the tensor to circulate plus its dtype and shape.

        With fusion enabled, K and V are stacked into a single
        [2, seq, heads, dims] tensor; otherwise K is returned unchanged.

        Returns:
            tuple: (circulating tensor, original dtype, its shape)
        """
        dtype = k.dtype
        if not use_kv_fusion:
            return k, dtype, k.shape

        fused = torch.stack([k, v], dim=0).reshape(2, k.shape[1], k.shape[2], k.shape[3]).contiguous()
        return fused, dtype, fused.shape

    @staticmethod
    def _dequantize_received(next_tensor_fp8, next_tensor_scale, original_dtype, original_shape, use_kv_fusion=False, is_kv_fusion=False, v_fp8=None, v_scale=None):
        """Dequantize FP8 tensors received from the ring.

        Args:
            next_tensor_fp8: received quantized tensor (KV or K)
            next_tensor_scale: its scale tensor
            original_dtype: dtype to dequantize back to
            original_shape: original shape (unused here; kept for symmetry)
            use_kv_fusion: whether fused-KV mode is active
            is_kv_fusion: whether the received tensor is a fused KV block
            v_fp8, v_scale: V payload and scale in the separate-tensor mode

        Returns:
            (k, v) in separate mode, otherwise the single dequantized tensor.
        """
        if not use_kv_fusion:
            # Separate mode: K and V arrive as two independent payloads.
            return (
                dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype),
                dequant_fp8_vllm(v_fp8, v_scale, original_dtype),
            )
        # Fusion mode: one tensor carries both K and V.
        return dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype)

    @staticmethod
    def _send_recv_tensor(tensor, hidden_dims, comm, use_fp8_comm, original_shape=None):
        """Send/receive a tensor, optionally via the FP8 path.

        Args:
            tensor: tensor to exchange
            hidden_dims: hidden dimension (needed for FP8 quantization)
            comm: ring communicator
            use_fp8_comm: quantize before sending when True
            original_shape: original shape, forwarded to the FP8 path

        Returns:
            tuple: (received tensor, scale tensor or None)
        """
        if not use_fp8_comm:
            return comm.send_recv(tensor), None
        return RingAttnHelper._quant_and_send(tensor, hidden_dims, comm, original_shape)

    @staticmethod
    def _get_text_lengths(cu_seqlens_qkv, img_qkv_len):
        """Derive text lengths from the cumulative sequence lengths.

        Args:
            cu_seqlens_qkv: cumulative sequence lengths (length 2 or 3)
            img_qkv_len: image sequence length

        Returns:
            tuple: (text q/k/v length, text mask length — 0 when absent)

        Raises:
            ValueError: if ``cu_seqlens_qkv`` has an unexpected length.
        """
        n = len(cu_seqlens_qkv)
        if n not in (2, 3):
            raise ValueError(f"Invalid cu_seqlens_qkv length: {n}")

        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len if n == 3 else 0
        return txt_qkv_len, txt_mask_len