Commit 0dd7ca09 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Fix] Fix sage-attn distribute bug (#235)

parent 79c3caa2
...@@ -12,7 +12,7 @@ class UlyssesAttnWeight(AttnWeightTemplate): ...@@ -12,7 +12,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
def __init__(self): def __init__(self):
self.config = {} self.config = {}
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None): def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None):
""" """
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。 执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
...@@ -77,7 +77,7 @@ class UlyssesAttnWeight(AttnWeightTemplate): ...@@ -77,7 +77,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
# 调用注意力函数计算注意力结果 # 调用注意力函数计算注意力结果
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv) # attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv) attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)
# 分割图像和文本的注意力结果 # 分割图像和文本的注意力结果
img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,] img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,]
......
...@@ -164,6 +164,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -164,6 +164,7 @@ class WanTransformerInfer(BaseTransformerInfer):
cu_seqlens_qkv=cu_seqlens_q, cu_seqlens_qkv=cu_seqlens_q,
attention_module=weights.self_attn_1, attention_module=weights.self_attn_1,
seq_p_group=self.seq_p_group, seq_p_group=self.seq_p_group,
model_cls=self.config["model_cls"],
) )
else: else:
attn_out = weights.self_attn_1.apply( attn_out = weights.self_attn_1.apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment