Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
962d2c63
Unverified
Commit
962d2c63
authored
Oct 20, 2024
by
Michael Goin
Committed by
GitHub
Oct 20, 2024
Browse files
[Model][Pixtral] Use memory_efficient_attention for PixtralHFVision (#9520)
parent
5b59fe0f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
41 deletions
+21
-41
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+21
-41
No files found.
vllm/model_executor/models/pixtral.py
View file @
962d2c63
...
@@ -13,8 +13,7 @@ from transformers import PixtralVisionConfig, PretrainedConfig
...
@@ -13,8 +13,7 @@ from transformers import PixtralVisionConfig, PretrainedConfig
from
transformers.models.pixtral.image_processing_pixtral
import
(
from
transformers.models.pixtral.image_processing_pixtral
import
(
_num_image_tokens
)
_num_image_tokens
)
from
transformers.models.pixtral.modeling_pixtral
import
(
from
transformers.models.pixtral.modeling_pixtral
import
(
PixtralRotaryEmbedding
,
apply_rotary_pos_emb
,
PixtralRotaryEmbedding
,
apply_rotary_pos_emb
,
position_ids_in_meshgrid
)
generate_block_attention_mask
,
position_ids_in_meshgrid
)
from
xformers.ops.fmha
import
memory_efficient_attention
from
xformers.ops.fmha
import
memory_efficient_attention
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
...
@@ -813,48 +812,30 @@ class PixtralHFAttention(nn.Module):
...
@@ -813,48 +812,30 @@ class PixtralHFAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
attention_mask
:
BlockDiagonalMask
,
position_embeddings
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Input shape: Batch x Time x Channel"""
batch
,
patches
,
_
=
hidden_states
.
size
()
batch_size
,
patches
,
_
=
hidden_states
.
size
()
q
=
self
.
q_proj
(
hidden_states
)
k
=
self
.
k_proj
(
hidden_states
)
query_states
=
self
.
q_proj
(
hidden_states
)
v
=
self
.
v_proj
(
hidden_states
)
key_states
=
self
.
k_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
query_states
=
query_states
.
view
(
batch_size
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
batch_size
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
batch_size
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
# Transpose q and k to apply HF's Rotary Position Embedding
q
=
q
.
view
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k
=
k
.
view
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
cos
,
sin
=
position_embeddings
cos
,
sin
=
position_embeddings
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
q
,
k
=
apply_rotary_pos_emb
(
q
,
k
,
cos
,
sin
,
unsqueeze_dim
=
0
)
key_states
,
cos
,
sin
,
unsqueeze_dim
=
0
)
attn_weights
=
torch
.
matmul
(
query_states
,
key_states
.
transpose
(
2
,
3
))
*
self
.
scale
if
attention_mask
is
not
None
:
attn_weights
=
attn_weights
+
attention_mask
# upcast attention to fp32
# Transpose q and k back for attention
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
,
q
=
q
.
transpose
(
1
,
2
).
contiguous
()
dim
=-
1
,
k
=
k
.
transpose
(
1
,
2
).
contiguous
()
dtype
=
torch
.
float32
).
to
(
v
=
v
.
reshape
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
)
query_states
.
dtype
)
attn_output
=
torch
.
matmul
(
attn_weights
,
value_states
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
(
)
out
=
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
attention_mask
)
attn_output
=
attn_outp
ut
.
reshape
(
batch
_size
,
patches
,
-
1
)
out
=
o
ut
.
reshape
(
batch
,
patches
,
self
.
n_heads
*
self
.
head_dim
)
return
self
.
o_proj
(
attn_outp
ut
)
return
self
.
o_proj
(
o
ut
)
class
PixtralHFTransformerBlock
(
nn
.
Module
):
class
PixtralHFTransformerBlock
(
nn
.
Module
):
...
@@ -869,7 +850,7 @@ class PixtralHFTransformerBlock(nn.Module):
...
@@ -869,7 +850,7 @@ class PixtralHFTransformerBlock(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
attention_mask
:
BlockDiagonalMask
,
position_embeddings
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
=
self
.
attention
.
forward
(
self
.
attention_norm
(
hidden_states
),
r
=
self
.
attention
.
forward
(
self
.
attention_norm
(
hidden_states
),
...
@@ -892,7 +873,7 @@ class PixtralHFTransformer(nn.Module):
...
@@ -892,7 +873,7 @@ class PixtralHFTransformer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
attention_mask
:
BlockDiagonalMask
,
position_embeddings
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
...
@@ -953,9 +934,8 @@ class PixtralHFVisionModel(nn.Module):
...
@@ -953,9 +934,8 @@ class PixtralHFVisionModel(nn.Module):
position_embedding
=
self
.
patch_positional_embedding
(
position_embedding
=
self
.
patch_positional_embedding
(
patch_embeds
,
position_ids
)
patch_embeds
,
position_ids
)
attention_mask
=
generate_block_attention_mask
(
attention_mask
=
BlockDiagonalMask
.
from_seqlens
(
[
p
.
shape
[
-
2
]
*
p
.
shape
[
-
1
]
for
p
in
patch_embeds_list
],
[
p
.
shape
[
-
2
]
*
p
.
shape
[
-
1
]
for
p
in
patch_embeds_list
],
)
patch_embeds
)
out
=
self
.
transformer
(
patch_embeds
,
attention_mask
,
out
=
self
.
transformer
(
patch_embeds
,
attention_mask
,
position_embedding
)
position_embedding
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment