Unverified commit 5cd8025f authored by Wang, Yi, committed by GitHub

hotfix: fix regression of attention api change on Intel platform (#2439)



Fix a regression caused by the attention API change: ipex.varlen_attention does not currently support paged-cache-format KV input.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent e279b38a
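
All of the modeling-file changes below follow one pattern: on the prefill path, the K/V arguments passed to attention() now depend on the backend, because ipex.llm.functional.varlen_attention cannot consume the paged-cache layout. The following minimal sketch is not part of the commit; the helper name pick_prefill_kv and the module-level SYSTEM constant are illustrative assumptions, standing in for the inlined expressions in the diff.

import torch
from typing import Tuple

SYSTEM = "ipex"  # in TGI this is resolved by text_generation_server.utils.import_utils

def pick_prefill_kv(
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    key: torch.Tensor,
    value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical helper mirroring the call-site expressions in this commit:
    # non-IPEX backends keep reading K/V from the paged cache, while the IPEX
    # backend is fed the freshly computed per-token key/value tensors.
    if SYSTEM != "ipex":
        return kv_cache[0], kv_cache[1]
    return key, value

In the diff itself this choice is inlined at each call site, e.g. kv_cache[0] if SYSTEM != "ipex" else key.
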
@@ -171,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 FROM ${PLATFORM} AS final
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
@@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False
 def attention(
-    q,
-    k,
-    v,
-    cu_seqlens,
-    max_s,
+    q: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
     softmax_scale,
     window_size_left=-1,
     causal=True,
@@ -23,13 +23,13 @@ def attention(
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
         q,
-        k,
-        v,
+        key_cache,
+        value_cache,
         out,
-        cu_seqlens,
-        cu_seqlens,
-        max_s,
-        max_s,
+        seqlen.cu_seqlen_q,
+        seqlen.cu_seqlen_q,
+        seqlen.max_q,
+        seqlen.max_q,
         0.0,
         softmax_scale,
         False,
...
@@ -297,8 +297,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -336,8 +336,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -363,8 +363,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 def load_attention(config, prefix: str, weights):
@@ -193,8 +192,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -220,8 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -218,8 +218,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -275,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
+                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -25,6 +25,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 class PhiConfig(PretrainedConfig):
@@ -193,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -21,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 def load_attention(config, prefix, weights):
@@ -136,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -5,7 +5,7 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -325,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=2, index=0),
-                torch.select(kv, dim=2, index=1),
-                kv_cache[0],
-                kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
             # Decode
...
@@ -22,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 def load_multi_mqa(
@@ -292,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...
@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+from text_generation_server.utils.import_utils import SYSTEM
 class Starcoder2Config(PretrainedConfig):
@@ -241,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
...