Unverified Commit 5d97e0c4 authored by Wang, Yi, committed by GitHub

Fix FlashDecoding change's regression on Intel platform (#2161)



Install triton because GPTQParams needs it.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 022f6515
@@ -62,6 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
 WORKDIR /usr/src
 RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
+RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
 RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed

 # Install server
@@ -132,6 +133,7 @@ RUN conda install -c conda-forge gperftools mkl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install triton
 WORKDIR /usr/src
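With these two additions, triton is importable in both images: the XPU image gets Intel's XPU-backend wheel, and the CPU image gets the stock PyPI package. A minimal smoke test, run inside either container, is just to import the module; the snippet below is an illustration, not part of the commit:

    # Hypothetical smoke test, not part of this commit. GPTQParams pulls
    # in triton at import time, so a failing import here is exactly the
    # regression the Dockerfile change guards against.
    import triton

    print(triton.__version__)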
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
+from text_generation_server.layers.attention import Seqlen

 SUPPORTS_WINDOWING = False
@@ -55,11 +56,10 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    seqlen: Seqlen,
     max_s: int,
 ):
-    return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
         key_cache,
@@ -67,8 +67,9 @@ def paged_attention(
         kv_head_mapping,
         softmax_scale,
         block_tables,
-        cu_seqlen_q,
+        seqlen.input_lengths,
         BLOCK_SIZE,
         max_s,
         None,
     )
+    return out
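Pieced together from the hunks above, the patched paged_attention for the IPEX backend reads roughly as below. The out/query/key_cache/value_cache parameters and their annotations come from context lines between the hunks and are partly assumed, so treat this as a sketch of the post-commit function rather than a verbatim copy of the file:

    import intel_extension_for_pytorch as ipex
    import torch

    from text_generation_server.layers.attention import Seqlen
    from text_generation_server.models.flash_causal_lm import BLOCK_SIZE

    SUPPORTS_WINDOWING = False


    def paged_attention(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,  # assumed: elided between the two hunks
        kv_head_mapping: torch.Tensor,
        softmax_scale: float,
        block_tables: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
    ):
        # The IPEX kernel writes its result into `out` in place, so the
        # function now returns `out` explicitly instead of forwarding the
        # kernel's return value.
        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,
            query,
            key_cache,
            value_cache,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            seqlen.input_lengths,  # replaces the old cu_seqlen_q argument
            BLOCK_SIZE,
            max_s,
            None,
        )
        return out

The upstream FlashDecoding refactor replaced the separate cu_seqlen_q/cu_seqlen_k tensors with a single Seqlen object; passing seqlen.input_lengths keeps feeding the IPEX kernel the plain per-sequence lengths it expects.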