"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "ac07160c8da87802a84c01598af5a39b4660b28e"
Unverified Commit 57f9685d authored by Wang, Yi, committed by GitHub

enable mllama in intel platform (#2610)


Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 0da4df4b
```diff
@@ -19,7 +19,12 @@ from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import flash_attn_2_cuda
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import flash_attn_2_cuda
 
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
```
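This import change routes attention through Intel Extension for PyTorch (IPEX) when the server detects an Intel platform, and only falls back to the CUDA flash-attention kernel otherwise. For context, here is a minimal sketch of how a `SYSTEM` flag like the one imported from `text_generation_server.utils.import_utils` might be derived; the `detect_system` helper below is hypothetical, not the repository's actual implementation:

```python
# Hypothetical sketch only: illustrates how a SYSTEM flag such as the one
# imported above could be derived. Not the actual implementation in
# text_generation_server/utils/import_utils.py.
import importlib.util


def detect_system() -> str:
    # Treat an installed intel_extension_for_pytorch as an Intel (ipex)
    # platform; fall back to CUDA otherwise.
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        return "ipex"
    return "cuda"


SYSTEM = detect_system()
```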
```diff
@@ -698,6 +703,37 @@ class MllamaTextCrossAttention(nn.Module):
         # logger.info(
         #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
+        if SYSTEM == "ipex":
+            attn_output = torch.empty_like(query_states)
+            ipex.llm.functional.varlen_attention(
+                (
+                    query_states.contiguous()
+                    if query_states.device.type == "xpu"
+                    else query_states
+                ),
+                (
+                    key_states.contiguous()
+                    if key_states.device.type == "xpu"
+                    else key_states
+                ),
+                (
+                    value_states.contiguous()
+                    if value_states.device.type == "xpu"
+                    else value_states
+                ),
+                attn_output,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+            )
+        else:
-        attn_output = flash_attn_2_cuda.varlen_fwd(
-            query_states,
-            key_states,
+            attn_output = flash_attn_2_cuda.varlen_fwd(
+                query_states,
+                key_states,
...
```
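The IPEX path preallocates the output tensor with `torch.empty_like` and passes arguments in the same order as `flash_attn_2_cuda.varlen_fwd`, making inputs contiguous on `xpu` devices first. Below is a minimal standalone sketch of the same `ipex.llm.functional.varlen_attention` call on dummy packed sequences; the shapes, dtype, and softmax scale are illustrative assumptions, and it needs an Intel XPU device to actually run:

```python
# Minimal sketch of the varlen_attention call introduced above, on dummy
# data. Argument order mirrors the call in this diff; shapes, dtype, and
# the softmax scale are illustrative assumptions.
import torch
import intel_extension_for_pytorch as ipex

num_heads, head_dim = 4, 64
total_q = total_k = 8  # one packed sequence of length 8

device = "xpu"  # requires an Intel XPU device
q = torch.randn(total_q, num_heads, head_dim, dtype=torch.float16, device=device)
k = torch.randn(total_k, num_heads, head_dim, dtype=torch.float16, device=device)
v = torch.randn(total_k, num_heads, head_dim, dtype=torch.float16, device=device)
out = torch.empty_like(q)  # varlen_attention writes the result in place

# Cumulative sequence lengths delimit each sequence in the packed batch.
cu_seqlen_q = torch.tensor([0, total_q], dtype=torch.int32, device=device)
cu_seqlen_k = torch.tensor([0, total_k], dtype=torch.int32, device=device)

ipex.llm.functional.varlen_attention(
    q.contiguous(),   # the diff makes inputs contiguous on xpu
    k.contiguous(),
    v.contiguous(),
    out,              # preallocated output, filled in place
    cu_seqlen_q,
    cu_seqlen_k,
    total_q,          # max_q: longest query sequence in the batch
    total_k,          # max_k: longest key sequence in the batch
    0.0,              # dropout probability, as in the diff
    head_dim**-0.5,   # softmax scale (assumed 1/sqrt(head_dim))
    False,            # zero_tensors
    True,             # is_causal (the diff passes its `causal` flag here)
    False,            # return_softmax
    None,             # generator
)
```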