# ulysses_attn.py
import torch
import torch.distributed as dist

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate
from .utils.all2all import all2all_head2seq, all2all_seq2head


@ATTN_WEIGHT_REGISTER("ulysses")
class UlyssesAttnWeight(AttnWeightTemplate):
    """Ulysses attention over a joint image + text token sequence.

    Image tokens are exchanged between the ranks of ``seq_p_group`` with the
    ``all2all_seq2head`` / ``all2all_head2seq`` helpers, while text tokens are
    kept whole on every rank and only sliced along the head dimension.
    """

    def __init__(self):
        # "run_device" is filled in lazily by device_synchronize().
        self.config = {}

    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None):
        """
        Run Ulysses attention on combined image and text queries, keys and values.

        Args:
            q (torch.Tensor): Query tensor, documented by the original authors as
                [shard_seqlen, heads, hidden_dims]. A 4-D input is flattened over
                its leading dimensions first. The first ``img_qkv_len`` rows are
                treated as image tokens, the remaining rows as text tokens.
            k (torch.Tensor): Key tensor, same layout as ``q``.
            v (torch.Tensor): Value tensor, same layout as ``q``.
            img_qkv_len (int): Number of image rows held by this rank.
            cu_seqlens_qkv (torch.Tensor): Cumulative sequence lengths with 2 or 3
                entries. Entry 1 is the end of image + text; the optional entry 2
                is the end of image + text mask.
            attention_module: Object whose ``apply(q=..., k=..., v=..., cu_seqlens_q=...,
                cu_seqlens_kv=..., max_seqlen_q=..., max_seqlen_kv=..., model_cls=...)``
                performs the actual attention. Must be provided despite the
                ``None`` default.
            seq_p_group: Process group handed to the ``torch.distributed`` calls.
            model_cls: Forwarded unchanged to ``attention_module.apply``.

        Returns:
            torch.Tensor: Image attention (``img_qkv_len`` rows) followed by the
            text attention rows, concatenated along dim 0.

        Raises:
            ValueError: If ``cu_seqlens_qkv`` does not have 2 or 3 entries.
        """
        if len(q.shape) == 4:
            q = q.reshape(-1, q.shape[-2], q.shape[-1])
            k = k.reshape(-1, k.shape[-2], k.shape[-1])
            v = v.reshape(-1, v.shape[-2], v.shape[-1])

        # Size of, and this process's rank inside, the sequence-parallel group.
        world_size = dist.get_world_size(seq_p_group)
        cur_rank = dist.get_rank(seq_p_group)

        # Text lengths are derived from the incoming cumulative lengths.
        num_boundaries = len(cu_seqlens_qkv)
        if num_boundaries not in (2, 3):
            raise ValueError(f"cu_seqlens_qkv must have 2 or 3 entries, got {num_boundaries}")
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len if num_boundaries == 3 else None

        _, heads, hidden_dims = q.shape
        shard_heads = heads // world_size  # heads handled by each rank
        shard_seqlen = img_qkv_len  # image rows held by each rank

        # Split image rows from text rows.
        img_q, img_k, img_v = (t[:img_qkv_len, :, :].contiguous() for t in (q, k, v))
        txt_q, txt_k, txt_v = (t[img_qkv_len:, :, :].contiguous() for t in (q, k, v))

        # Exchange image q/k/v across the group (sequence layout -> head layout).
        img_q = all2all_seq2head(img_q, group=seq_p_group)
        img_k = all2all_seq2head(img_k, group=seq_p_group)
        img_v = all2all_seq2head(img_v, group=seq_p_group)
        self.device_synchronize()  # make sure the device work has finished

        # Text tokens: keep only the heads owned by this rank.
        head_slice = slice(cur_rank * shard_heads, (cur_rank + 1) * shard_heads)
        txt_q = txt_q[:, head_slice, :]
        txt_k = txt_k[:, head_slice, :]
        txt_v = txt_v[:, head_slice, :]

        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)
        v = torch.cat((img_v, txt_v), dim=0)

        # Rebuild the cumulative lengths for the regrouped sequence. The tensor is
        # allocated at its final size because the previous
        # ``torch.cat(cu_seqlens_qkv, s2)`` call was invalid (torch.cat expects a
        # sequence of tensors) and raised whenever a text mask was present.
        img_total_len = img_q.shape[0]
        has_txt_mask = txt_mask_len is not None
        cu_seqlens_qkv = torch.zeros([3 if has_txt_mask else 2], dtype=torch.int32, device=self.config.get("run_device", "cuda"))
        cu_seqlens_qkv[1] = txt_qkv_len + img_total_len  # end of image + text
        if has_txt_mask:
            cu_seqlens_qkv[2] = txt_mask_len + img_total_len  # end of image + text mask
        max_seqlen_qkv = img_total_len + txt_q.shape[0]

        attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)

        # Split the result back into image rows and text rows.
        img_attn, txt_attn = attn[:img_total_len], attn[img_total_len:]

        # Collect every rank's text attention (each rank computed a head slice).
        gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
        dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)

        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group)

        txt_attn = torch.cat(gathered_txt_attn, dim=1)

        return torch.cat([img_attn, txt_attn], dim=0)

    @torch.compiler.disable
    def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
        """Convert image attention from head layout back to this rank's sequence shard.

        Returns a 2-D tensor with ``shard_seqlen`` rows.
        """
        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)
        img_attn = all2all_head2seq(img_attn, group=seq_p_group)  # head layout -> sequence layout
        img_attn = img_attn.reshape(shard_seqlen, -1)
        self.device_synchronize()  # make sure the device work has finished
        return img_attn

    def device_synchronize(
        self,
    ):
        """Synchronize the active accelerator and record it in ``self.config["run_device"]``.

        CUDA is preferred; MLU is used when CUDA is unavailable and ``torch.mlu``
        exists. If neither is available nothing happens and ``run_device`` is
        left untouched.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            self.config["run_device"] = "cuda"
        elif hasattr(torch, "mlu") and torch.mlu.is_available():
            torch.mlu.synchronize()
            self.config["run_device"] = "mlu"


@ATTN_WEIGHT_REGISTER("ulysses-4090")
class Ulysses4090AttnWeight(AttnWeightTemplate):
    """Ulysses attention variant that exchanges image shards with point-to-point sends.

    Instead of an all-to-all helper, every pair of ranks in ``seq_p_group``
    swaps its head / sequence shards through ``dist.isend`` / ``dist.irecv``.
    """

    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None):
        """
        Run Ulysses attention on combined image and text queries, keys and values.

        Args:
            q (torch.Tensor): Query tensor, documented by the original authors as
                [shard_seqlen, heads, hidden_dims]. A 4-D input is flattened over
                its leading dimensions first. The first ``img_qkv_len`` rows are
                treated as image tokens, the remaining rows as text tokens.
            k (torch.Tensor): Key tensor, same layout as ``q``.
            v (torch.Tensor): Value tensor, same layout as ``q``.
            img_qkv_len (int): Number of image rows held by this rank.
            cu_seqlens_qkv (torch.Tensor): Cumulative sequence lengths with 2 or 3
                entries. Entry 1 is the end of image + text; the optional entry 2
                is the end of image + text mask.
            attention_module: Object whose ``apply(q=..., k=..., v=..., cu_seqlens_q=...,
                cu_seqlens_kv=..., max_seqlen_q=..., max_seqlen_kv=..., model_cls=...)``
                performs the actual attention. Must be provided despite the
                ``None`` default.
            seq_p_group: Process group handed to the ``torch.distributed`` calls.
            model_cls: Forwarded unchanged to ``attention_module.apply``.

        Returns:
            torch.Tensor: Image attention (``img_qkv_len`` rows) followed by the
            text attention rows, concatenated along dim 0.

        Raises:
            ValueError: If ``cu_seqlens_qkv`` does not have 2 or 3 entries.
        """
        if len(q.shape) == 4:
            q = q.reshape(-1, q.shape[-2], q.shape[-1])
            k = k.reshape(-1, k.shape[-2], k.shape[-1])
            v = v.reshape(-1, v.shape[-2], v.shape[-1])

        # Size of, and this process's rank inside, the sequence-parallel group.
        world_size = dist.get_world_size(seq_p_group)
        cur_rank = dist.get_rank(seq_p_group)

        # Text lengths are derived from the incoming cumulative lengths.
        num_boundaries = len(cu_seqlens_qkv)
        if num_boundaries not in (2, 3):
            raise ValueError(f"cu_seqlens_qkv must have 2 or 3 entries, got {num_boundaries}")
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len if num_boundaries == 3 else None

        _, heads, hidden_dims = q.shape
        shard_heads = heads // world_size  # heads handled by each rank
        shard_seqlen = img_qkv_len  # image rows held by each rank

        # Split image rows from text rows.
        img_q, img_k, img_v = (t[:img_qkv_len, :, :].contiguous() for t in (q, k, v))
        txt_q, txt_k, txt_v = (t[img_qkv_len:, :, :].contiguous() for t in (q, k, v))

        # Cut the image q/k/v into one head shard per rank; shard i is sent to rank i.
        def split_heads(t):
            return [t[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]

        q_shards, k_shards, v_shards = split_heads(img_q), split_heads(img_k), split_heads(img_v)

        # Receive buffers; the local shard is reused directly instead of being sent.
        def recv_buffers(shards):
            return [shards[r] if r == cur_rank else torch.empty_like(shards[r]) for r in range(world_size)]

        gathered_q_shards = recv_buffers(q_shards)
        gathered_k_shards = recv_buffers(k_shards)
        gathered_v_shards = recv_buffers(v_shards)

        # Pairwise exchange: post the async ops for one peer, then wait on them.
        for target_rank in range(world_size):
            if target_rank == cur_rank:
                continue
            # isend/irecv address peers by global rank; ask the group for the
            # mapping rather than assuming groups are contiguous blocks of ranks.
            target_global_rank = dist.get_global_rank(seq_p_group, target_rank)
            outgoing = (q_shards[target_rank], k_shards[target_rank], v_shards[target_rank])
            incoming = (gathered_q_shards[target_rank], gathered_k_shards[target_rank], gathered_v_shards[target_rank])
            # Avoid deadlock: the lower rank posts sends first, the higher rank posts receives first.
            if cur_rank < target_rank:
                send_reqs = [dist.isend(t, dst=target_global_rank, group=seq_p_group) for t in outgoing]
                recv_reqs = [dist.irecv(t, src=target_global_rank, group=seq_p_group) for t in incoming]
            else:
                recv_reqs = [dist.irecv(t, src=target_global_rank, group=seq_p_group) for t in incoming]
                send_reqs = [dist.isend(t, dst=target_global_rank, group=seq_p_group) for t in outgoing]
            for req in send_reqs + recv_reqs:
                req.wait()

        # Concatenate the shards along the sequence dimension.
        # Each gathered_*_shards[i] is (seq_len/N, num_heads/N, head_dim), so the
        # result is (seq_len, num_heads/N, head_dim).
        img_q = torch.cat(gathered_q_shards, dim=0)
        img_k = torch.cat(gathered_k_shards, dim=0)
        img_v = torch.cat(gathered_v_shards, dim=0)

        # Text tokens: keep only the heads owned by this rank.
        head_slice = slice(cur_rank * shard_heads, (cur_rank + 1) * shard_heads)
        txt_q = txt_q[:, head_slice, :]
        txt_k = txt_k[:, head_slice, :]
        txt_v = txt_v[:, head_slice, :]

        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)
        v = torch.cat((img_v, txt_v), dim=0)

        # Rebuild the cumulative lengths for the regrouped sequence. The tensor is
        # allocated at its final size because the previous
        # ``torch.cat(cu_seqlens_qkv, s2)`` call was invalid (torch.cat expects a
        # sequence of tensors) and raised whenever a text mask was present.
        img_total_len = img_q.shape[0]
        has_txt_mask = txt_mask_len is not None
        cu_seqlens_qkv = torch.zeros([3 if has_txt_mask else 2], dtype=torch.int32, device="cuda")
        cu_seqlens_qkv[1] = txt_qkv_len + img_total_len  # end of image + text
        if has_txt_mask:
            cu_seqlens_qkv[2] = txt_mask_len + img_total_len  # end of image + text mask
        max_seqlen_qkv = img_total_len + txt_q.shape[0]

        attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)

        # Split the result back into image rows and text rows.
        img_attn, txt_attn = attn[:img_total_len], attn[img_total_len:]

        # Collect every rank's text attention (each rank computed a head slice).
        gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
        dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)

        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group)

        txt_attn = torch.cat(gathered_txt_attn, dim=1)

        return torch.cat([img_attn, txt_attn], dim=0)

    @torch.compiler.disable
    def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
        """Send image attention back to the ranks owning each sequence shard.

        The input is viewed as ``world_size`` sequence shards, shard i is sent to
        rank i, and the shards received from all ranks are concatenated along
        the head dimension. Returns a 2-D tensor with ``shard_seqlen`` rows.
        """
        cur_rank = dist.get_rank(seq_p_group)

        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)

        # Cut along the sequence dimension into one shard per rank.
        attn_shards = [img_attn[i * shard_seqlen : (i + 1) * shard_seqlen, :, :].contiguous() for i in range(world_size)]
        # Receive buffers; the local shard is reused directly instead of being sent.
        gathered_attn_shards = [attn_shards[r] if r == cur_rank else torch.empty_like(attn_shards[r]) for r in range(world_size)]

        # Pairwise exchange: post the async ops for one peer, then wait on them.
        for target_rank in range(world_size):
            if target_rank == cur_rank:
                continue
            # isend/irecv address peers by global rank; ask the group for the
            # mapping rather than assuming groups are contiguous blocks of ranks.
            target_global_rank = dist.get_global_rank(seq_p_group, target_rank)
            # Avoid deadlock: the lower rank posts its send first, the higher rank its receive first.
            if cur_rank < target_rank:
                send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
                recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
            else:
                recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
                send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
            send_req.wait()
            recv_req.wait()

        # Concatenate the shards along the head dimension.
        img_attn = torch.cat(gathered_attn_shards, dim=1)
        img_attn = img_attn.reshape(shard_seqlen, -1)

        torch.cuda.synchronize()  # make sure the CUDA work has finished
        return img_attn