Unverified Commit 75e9d497 authored by Junlin Zhou's avatar Junlin Zhou Committed by GitHub
Browse files

[Bugfix] Initialize attention bias on the same device as Query/Key/Value (#13468)

parent 32c3b6bf
......@@ -673,7 +673,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Cross-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
attn_metadata.seq_lens,
attn_metadata.encoder_seq_lens,
device=query.device)
# Encoder branch of encoder-decoder model uses
# attn_metadata.encoder_seq_lens
......@@ -683,7 +685,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.encoder_seq_lens)
attn_metadata.encoder_seq_lens, device=query.device)
# Self-attention block of encoder-only model just
# uses the seq_lens directly.
......@@ -692,7 +694,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens)
attn_metadata.seq_lens, device=query.device)
# Self-attention block of decoder branch just
# uses the seq_lens directly
......@@ -701,7 +703,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# Decoder self-attention mask is causal
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
attn_metadata.seq_lens, device=query.device)
else:
raise ValueError("Unknown AttentionType: %s", attn_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment