OpenDAS / text-generation-inference · Commit 57f9685d (unverified)

Authored Oct 08, 2024 by Wang, Yi; committed by GitHub, Oct 07, 2024

enable mllama in intel platform (#2610)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Parent: 0da4df4b
1 changed file with 60 additions and 24 deletions:

server/text_generation_server/models/custom_modeling/mllama.py (+60, -24)
@@ -19,7 +19,12 @@ from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import flash_attn_2_cuda
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import flash_attn_2_cuda
 
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
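The hunk above keys the attention backend off the SYSTEM value exported by text_generation_server.utils.import_utils. As a rough, hypothetical illustration only (the project's real detection logic is not part of this commit), such a flag could be derived by probing which backend packages are importable:

# Hypothetical sketch, NOT the project's import_utils: derive a backend flag
# from the packages that can actually be imported on this machine.
import importlib.util

import torch


def detect_system() -> str:
    # Prefer Intel Extension for PyTorch when it is installed (XPU or CPU builds).
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        return "ipex"
    # Otherwise fall back to the CUDA flash-attention path used in the else branch.
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


SYSTEM = detect_system()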
@@ -698,29 +703,60 @@ class MllamaTextCrossAttention(nn.Module):
         # logger.info(
         #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
-        attn_output = flash_attn_2_cuda.varlen_fwd(
-            query_states,
-            key_states,
-            value_states,
-            None,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            None,
-            None,  # block_tables
-            None,
-            max_q,
-            max_k,
-            0.0,
-            self.softmax_scale,
-            False,
-            causal,  # Causal
-            -1,  # window_size_left,
-            -1,
-            0.0,  # softcap
-            False,
-            None,
-        )[0]
+        if SYSTEM == "ipex":
+            attn_output = torch.empty_like(query_states)
+            ipex.llm.functional.varlen_attention(
+                (
+                    query_states.contiguous()
+                    if query_states.device.type == "xpu"
+                    else query_states
+                ),
+                (
+                    key_states.contiguous()
+                    if key_states.device.type == "xpu"
+                    else key_states
+                ),
+                (
+                    value_states.contiguous()
+                    if value_states.device.type == "xpu"
+                    else value_states
+                ),
+                attn_output,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+            )
+        else:
+            attn_output = flash_attn_2_cuda.varlen_fwd(
+                query_states,
+                key_states,
+                value_states,
+                None,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables
+                None,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,  # Causal
+                -1,  # window_size_left,
+                -1,
+                0.0,  # softcap
+                False,
+                None,
+            )[0]
         attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
         return attn_output
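Both branches implement the same variable-length ("varlen") attention over packed sequences; only the kernel differs. The following is a minimal pure-PyTorch reference sketch, not part of the commit, of what ipex.llm.functional.varlen_attention and flash_attn_2_cuda.varlen_fwd compute at this call site. It assumes q, k and v are packed as [total_tokens, num_heads, head_size] with per-sequence boundaries in cu_seqlen_q/cu_seqlen_k, equal query and key/value head counts, and the top-left causal-mask convention of torch SDPA (flash-attention aligns the causal mask bottom-right when query and key lengths differ):

import torch
import torch.nn.functional as F


def varlen_attention_reference(q, k, v, cu_seqlen_q, cu_seqlen_k, softmax_scale, causal):
    # Output keeps the same packed [total_tokens, num_heads, head_size] layout as q.
    out = torch.empty_like(q)
    # Attend each sequence in the packed batch independently.
    for i in range(cu_seqlen_q.numel() - 1):
        qs, qe = int(cu_seqlen_q[i]), int(cu_seqlen_q[i + 1])
        ks, ke = int(cu_seqlen_k[i]), int(cu_seqlen_k[i + 1])
        # Reshape to [1, num_heads, seq_len, head_size] for scaled_dot_product_attention.
        qi = q[qs:qe].transpose(0, 1).unsqueeze(0)
        ki = k[ks:ke].transpose(0, 1).unsqueeze(0)
        vi = v[ks:ke].transpose(0, 1).unsqueeze(0)
        oi = F.scaled_dot_product_attention(
            qi, ki, vi, is_causal=causal, scale=softmax_scale
        )
        out[qs:qe] = oi.squeeze(0).transpose(0, 1)
    return out

In the cross-attention above, the resulting attn_output is then flattened to [-1, num_heads * head_size] and passed through o_proj, a step both branches share.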