Unverified Commit 5cd8025f authored by Wang, Yi, committed by GitHub

hotfix: fix regression of attention api change in intel platform (#2439)



Fix a regression caused by the attention API change: ipex.varlen_attention does not currently support KV input in paged-cache format.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent e279b38a
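
The same selection is repeated in every modeling file touched below: during prefill, the ipex path hands the attention kernel the freshly projected key/value tensors of the current batch, while all other backends keep passing the paged KV cache blocks. A minimal sketch of that dispatch; select_prefill_kv is a hypothetical helper, key/value stand in for whatever each model computes, and the real change simply inlines the conditional at each call site:

# Sketch only: mirrors the `kv_cache[i] if SYSTEM != "ipex" else ...`
# expressions added throughout this commit.
from typing import Tuple

import torch

from text_generation_server.utils.import_utils import SYSTEM


def select_prefill_kv(
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    key: torch.Tensor,
    value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # ipex.llm.functional.varlen_attention cannot consume the paged-cache
    # layout, so the ipex backend gets the unpadded key/value projections
    # for the current batch instead of the cache blocks.
    if SYSTEM == "ipex":
        return key, value
    return kv_cache[0], kv_cache[1]
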
@@ -171,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 FROM ${PLATFORM} AS final
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
@@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False
 def attention(
-    q,
-    k,
-    v,
-    cu_seqlens,
-    max_s,
+    q: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
     softmax_scale,
     window_size_left=-1,
     causal=True,
@@ -23,13 +23,13 @@ def attention(
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
         q,
-        k,
-        v,
+        key_cache,
+        value_cache,
         out,
-        cu_seqlens,
-        cu_seqlens,
-        max_s,
-        max_s,
+        seqlen.cu_seqlen_q,
+        seqlen.cu_seqlen_q,
+        seqlen.max_q,
+        seqlen.max_q,
         0.0,
         softmax_scale,
         False,
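
The ipex wrapper above now reads its sequence metadata from a Seqlen object (seqlen.cu_seqlen_q, seqlen.max_q) rather than the old cu_seqlens/max_s arguments. A small, self-contained sketch of what those two fields hold; SeqlenStub and make_seqlen are illustrative stand-ins, not the real Seqlen class:

# Illustrative stub: the real Seqlen carries more state; this keeps only the
# two fields the ipex attention wrapper reads.
from dataclasses import dataclass

import torch


@dataclass
class SeqlenStub:
    cu_seqlen_q: torch.Tensor  # cumulative query lengths, shape [batch + 1]
    max_q: int  # longest query sequence in the batch


def make_seqlen(lengths) -> SeqlenStub:
    lengths = torch.as_tensor(lengths, dtype=torch.int32)
    cu = torch.zeros(lengths.numel() + 1, dtype=torch.int32)
    cu[1:] = torch.cumsum(lengths, dim=0)
    return SeqlenStub(cu_seqlen_q=cu, max_q=int(lengths.max()))


# Three prefill requests of 5, 2 and 7 tokens:
# make_seqlen([5, 2, 7]).cu_seqlen_q -> tensor([0, 5, 7, 14], dtype=torch.int32)
# make_seqlen([5, 2, 7]).max_q       -> 7
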
@@ -297,8 +297,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -336,8 +336,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -363,8 +363,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 def load_attention(config, prefix: str, weights):
@@ -193,8 +192,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -220,8 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -218,8 +218,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -275,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
+                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,6 +25,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 class PhiConfig(PretrainedConfig):
@@ -193,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -21,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 def load_attention(config, prefix, weights):
@@ -136,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -5,7 +5,7 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -325,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=2, index=0),
-                torch.select(kv, dim=2, index=1),
-                kv_cache[0],
-                kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
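
One note on the .contiguous() calls in the FlashRWLargeAttention hunk above: indexing the fused KV tensor along its key/value dimension yields a strided view, and the tensors handed to the ipex kernel presumably need a dense layout, hence the copy. A quick illustration with made-up shapes:

import torch

# Fused keys/values packed as [tokens, kv_groups, 2, head_size] (illustrative shape).
kv = torch.randn(16, 4, 2, 64)

key_view = kv[:, :, 0]  # strided view over the packed tensor
print(key_view.is_contiguous())               # False
print(key_view.contiguous().is_contiguous())  # True: dense copy
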
@@ -22,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 def load_multi_mqa(
@@ -292,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+from text_generation_server.utils.import_utils import SYSTEM
 class Starcoder2Config(PretrainedConfig):
@@ -241,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,