Unverified commit 02809af1, authored by Isotr0py, committed by GitHub
Browse files

[Bugfix]: Fix cross attention backend selection for Turing GPU (#31806)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent cbd4690a
......@@ -149,16 +149,20 @@ class CrossAttention(Attention):
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_cross_attention_backend(underlying_attn_backend)
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
"CrossAttention only supports AttentionType.ENCODER_DECODER"
)
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_DECODER,
)
attn_backend = create_cross_attention_backend(underlying_attn_backend)
super().__init__(
num_heads=num_heads,
head_size=head_size,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment