Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c91ed47c
Unverified
Commit
c91ed47c
authored
Oct 24, 2024
by
Michael Goin
Committed by
GitHub
Oct 24, 2024
Browse files
[Bugfix] Remove xformers requirement for Pixtral (#9597)
Signed-off-by:
mgoin
<
michael@neuralmagic.com
>
parent
59449095
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
19 deletions
+46
-19
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+46
-19
No files found.
vllm/model_executor/models/pixtral.py
View file @
c91ed47c
...
...
@@ -14,8 +14,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens
)
from
transformers.models.pixtral.modeling_pixtral
import
(
PixtralRotaryEmbedding
,
apply_rotary_pos_emb
,
position_ids_in_meshgrid
)
from
xformers.ops.fmha
import
memory_efficient_attention
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
MultiModalConfig
...
...
@@ -38,6 +36,12 @@ from vllm.utils import is_list_of
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
init_vllm_registered_model
try
:
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
def
get_max_pixtral_image_tokens
(
ctx
:
InputContext
):
tokenizer
=
cached_get_tokenizer
(
...
...
@@ -416,7 +420,7 @@ class Attention(nn.Module):
def
forward
(
self
,
x
:
torch
.
Tensor
,
mask
:
BlockDiagonalMask
,
mask
:
torch
.
Tensor
,
freqs_cis
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
batch
,
patches
,
_
=
x
.
shape
...
...
@@ -427,7 +431,7 @@ class Attention(nn.Module):
v
=
v
.
reshape
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
)
q
,
k
=
apply_rotary_emb_vit
(
q
,
k
,
freqs_cis
=
freqs_cis
)
out
=
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
mask
)
out
=
xops
.
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
mask
)
out
=
out
.
reshape
(
batch
,
patches
,
self
.
n_heads
*
self
.
head_dim
)
return
self
.
wo
(
out
)
...
...
@@ -444,7 +448,7 @@ class TransformerBlock(nn.Module):
def
forward
(
self
,
x
:
torch
.
Tensor
,
mask
:
BlockDiagonalMask
,
mask
:
torch
.
Tensor
,
freqs_cis
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
r
=
self
.
attention
.
forward
(
self
.
attention_norm
(
x
),
...
...
@@ -467,7 +471,7 @@ class Transformer(nn.Module):
def
forward
(
self
,
x
:
torch
.
Tensor
,
mask
:
BlockDiagonalMask
,
mask
:
torch
.
Tensor
,
freqs_cis
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
for
layer
in
self
.
layers
:
...
...
@@ -562,8 +566,12 @@ class VisionTransformer(nn.Module):
freqs_cis
=
self
.
freqs_cis
[
positions
[:,
0
],
positions
[:,
1
]]
# pass through Transformer with a block diagonal mask delimiting images
mask
=
BlockDiagonalMask
.
from_seqlens
(
if
USE_XFORMERS_OPS
:
mask
=
xops
.
fmha
.
attn_bias
.
BlockDiagonalMask
.
from_seqlens
(
[
p
.
shape
[
-
2
]
*
p
.
shape
[
-
1
]
for
p
in
patch_embeds_list
],
)
else
:
raise
ImportError
(
"Xformers is required for Pixtral inference "
"with the Mistral format"
)
out
=
self
.
transformer
(
patch_embeds
,
mask
=
mask
,
freqs_cis
=
freqs_cis
)
# remove batch dimension of the single sequence
...
...
@@ -828,7 +836,7 @@ class PixtralHFAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
BlockDiagonalMask
,
attention_mask
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
batch
,
patches
,
_
=
hidden_states
.
size
()
...
...
@@ -843,12 +851,23 @@ class PixtralHFAttention(nn.Module):
cos
,
sin
=
position_embeddings
q
,
k
=
apply_rotary_pos_emb
(
q
,
k
,
cos
,
sin
,
unsqueeze_dim
=
0
)
if
USE_XFORMERS_OPS
:
# Transpose q and k back for attention
q
=
q
.
transpose
(
1
,
2
).
contiguous
()
k
=
k
.
transpose
(
1
,
2
).
contiguous
()
v
=
v
.
reshape
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
)
out
=
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
attention_mask
)
out
=
xops
.
memory_efficient_attention
(
q
,
k
,
v
,
attn_bias
=
attention_mask
)
else
:
v
=
v
.
reshape
(
batch
,
patches
,
self
.
n_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
out
=
nn
.
functional
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attention_mask
)
out
=
out
.
transpose
(
1
,
2
)
out
=
out
.
reshape
(
batch
,
patches
,
self
.
n_heads
*
self
.
head_dim
)
return
self
.
o_proj
(
out
)
...
...
@@ -877,7 +896,7 @@ class PixtralHFTransformerBlock(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
BlockDiagonalMask
,
attention_mask
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
r
=
self
.
attention
.
forward
(
self
.
attention_norm
(
hidden_states
),
...
...
@@ -916,7 +935,7 @@ class PixtralHFTransformer(nn.Module):
def
forward
(
self
,
x
:
torch
.
Tensor
,
attention_mask
:
BlockDiagonalMask
,
attention_mask
:
torch
.
Tensor
,
position_embeddings
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
for
layer
in
self
.
layers
:
...
...
@@ -1000,11 +1019,19 @@ class PixtralHFVisionModel(nn.Module):
patch_embeds_list
,
max_width
=
self
.
config
.
image_size
//
self
.
config
.
patch_size
).
to
(
self
.
device
)
position_embedding
=
self
.
patch_positional_embedding
(
patch_embeds
,
position_ids
)
attention_mask
=
BlockDiagonalMask
.
from_seqlens
(
if
USE_XFORMERS_OPS
:
attention_mask
=
xops
.
fmha
.
attn_bias
.
BlockDiagonalMask
.
from_seqlens
(
[
p
.
shape
[
-
2
]
*
p
.
shape
[
-
1
]
for
p
in
patch_embeds_list
],
)
else
:
from
transformers.models.pixtral.modeling_pixtral
import
(
generate_block_attention_mask
)
attention_mask
=
generate_block_attention_mask
(
[
p
.
shape
[
-
2
]
*
p
.
shape
[
-
1
]
for
p
in
patch_embeds_list
],
patch_embeds
)
out
=
self
.
transformer
(
patch_embeds
,
attention_mask
,
position_embedding
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment