Unverified Commit 75e9d497 authored by Junlin Zhou's avatar Junlin Zhou Committed by GitHub
Browse files

[Bugfix] Initialize attention bias on the same device as Query/Key/Value (#13468)

parent 32c3b6bf
...@@ -673,7 +673,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -673,7 +673,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Cross-attention mask is non-causal # Cross-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens( attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) attn_metadata.seq_lens,
attn_metadata.encoder_seq_lens,
device=query.device)
# Encoder branch of encoder-decoder model uses # Encoder branch of encoder-decoder model uses
# attn_metadata.encoder_seq_lens # attn_metadata.encoder_seq_lens
...@@ -683,7 +685,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -683,7 +685,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Encoder self-attention mask is non-causal # Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens( attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.encoder_seq_lens) attn_metadata.encoder_seq_lens, device=query.device)
# Self-attention block of encoder-only model just # Self-attention block of encoder-only model just
# uses the seq_lens directly. # uses the seq_lens directly.
...@@ -692,7 +694,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -692,7 +694,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Encoder self-attention mask is non-causal # Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens( attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens) attn_metadata.seq_lens, device=query.device)
# Self-attention block of decoder branch just # Self-attention block of decoder branch just
# uses the seq_lens directly # uses the seq_lens directly
...@@ -701,7 +703,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -701,7 +703,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Decoder self-attention mask is causal # Decoder self-attention mask is causal
attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens) attn_metadata.seq_lens, device=query.device)
else: else:
raise ValueError("Unknown AttentionType: %s", attn_type) raise ValueError("Unknown AttentionType: %s", attn_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment