Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d31174a4
Unverified
Commit
d31174a4
authored
Sep 13, 2024
by
Patrick von Platen
Committed by
GitHub
Sep 12, 2024
Browse files
[Hotfix][Pixtral] Fix multiple images bugs (#8415)
parent
b61bd98f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
196 additions
and
77 deletions
+196
-77
tests/conftest.py
tests/conftest.py
+1
-1
tests/models/fixtures/pixtral_chat.pickle
tests/models/fixtures/pixtral_chat.pickle
+0
-0
tests/models/fixtures/pixtral_chat_engine.pickle
tests/models/fixtures/pixtral_chat_engine.pickle
+0
-0
tests/models/test_pixtral.py
tests/models/test_pixtral.py
+146
-42
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+49
-34
No files found.
tests/conftest.py
View file @
d31174a4
...
...
@@ -658,8 +658,8 @@ class VllmRunner:
outputs
.
append
((
req_sample_output_ids
,
req_sample_output_strs
))
return
outputs
@
staticmethod
def
_final_steps_generate_w_logprobs
(
self
,
req_outputs
:
List
[
RequestOutput
],
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
outputs
:
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]
=
[]
...
...
tests/models/fixtures/pixtral_chat.pickle
0 → 100644
View file @
d31174a4
File added
tests/models/fixtures/pixtral_chat_engine.pickle
0 → 100644
View file @
d31174a4
File added
tests/models/test_pixtral.py
View file @
d31174a4
...
...
@@ -2,13 +2,92 @@
Run `pytest tests/models/test_pixtral.py`.
"""
import
pickle
import
uuid
from
typing
import
Any
,
Dict
,
List
import
pytest
from
mistral_common.protocol.instruct.messages
import
ImageURLChunk
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
mistral_common.tokens.tokenizers.mistral
import
MistralTokenizer
from
mistral_common.tokens.tokenizers.multimodal
import
image_from_chunk
from
vllm
import
EngineArgs
,
LLMEngine
,
SamplingParams
,
TokensPrompt
from
vllm.multimodal
import
MultiModalDataBuiltins
from
vllm.sampling_params
import
SamplingParams
from
.utils
import
check_logprobs_close
# Apply the `vlm` pytest marker to every test in this module.
pytestmark
=
pytest
.
mark
.
vlm
# Model identifiers the tests below are parametrized over.
MODELS
=
[
"mistralai/Pixtral-12B-2409"
]
# Four remote test images; slices of this list are used to build requests
# containing one, two and four images.
# NOTE(review): the trailing path numbers look like width/height - confirm.
IMG_URLS
=
[
"https://picsum.photos/id/237/400/300"
,
"https://picsum.photos/id/231/200/300"
,
"https://picsum.photos/id/27/500/500"
,
"https://picsum.photos/id/17/150/600"
,
]
# Text chunk placed before the image chunks by _create_msg_format.
PROMPT
=
"Describe each image in one short sentence."
def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
    """Build a one-message chat request for the given image URLs.

    The returned list holds a single ``"user"`` message whose content is
    the ``PROMPT`` text chunk followed by one ``image_url`` chunk per
    entry in ``urls``, in the same order.
    """
    # The text chunk always comes first, then the images.
    content: List[Dict[str, Any]] = [{
        "type": "text",
        "text": PROMPT,
    }]
    for image_url in urls:
        content.append({
            "type": "image_url",
            "image_url": {
                "url": image_url
            }
        })

    return [{
        "role": "user",
        "content": content,
    }]
def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
    """Build a pre-tokenized engine prompt for the given image URLs.

    The chat message from ``_create_msg_format`` is wrapped in a
    ``ChatCompletionRequest`` and encoded with the ``"pixtral"``
    ``MistralTokenizer``. The resulting token ids become the prompt, and
    one image per ``ImageURLChunk`` of the first message is attached under
    ``"multi_modal_data"``.
    """
    chat_messages = _create_msg_format(urls)

    mistral_tokenizer = MistralTokenizer.from_model("pixtral")

    chat_request = ChatCompletionRequest(
        messages=chat_messages)  # type: ignore[type-var]
    encoded = mistral_tokenizer.encode_chat_completion(chat_request)

    prompt = TokensPrompt(prompt_token_ids=encoded.tokens)

    # Only the image chunks of the single user message carry image data.
    # NOTE(review): image_from_chunk presumably fetches and decodes the URL
    # into an image object - confirm against mistral_common.
    decoded_images = [
        image_from_chunk(part) for part in chat_request.messages[0].content
        if isinstance(part, ImageURLChunk)
    ]

    prompt["multi_modal_data"] = MultiModalDataBuiltins(image=decoded_images)

    return prompt
# Chat-format requests with the first one, the first two, and all of
# IMG_URLS; iterated by test_chat.
MSGS
=
[
_create_msg_format
(
IMG_URLS
[:
1
]),
_create_msg_format
(
IMG_URLS
[:
2
]),
_create_msg_format
(
IMG_URLS
),
]
# The same three image sets as pre-tokenized prompts for test_model_engine.
# NOTE: these calls run at module import time, so _create_engine_inputs
# (tokenizer construction and image_from_chunk) executes whenever this
# module is imported, before any test-level skip is evaluated.
ENGINE_INPUTS
=
[
_create_engine_inputs
(
IMG_URLS
[:
1
]),
_create_engine_inputs
(
IMG_URLS
[:
2
]),
_create_engine_inputs
(
IMG_URLS
),
]
# Sampling settings shared by test_chat and test_model_engine.
SAMPLING_PARAMS
=
SamplingParams
(
max_tokens
=
512
,
temperature
=
0.0
,
logprobs
=
5
)
# Passed as limit_mm_per_prompt; 4 equals the number of entries in IMG_URLS.
LIMIT_MM_PER_PROMPT
=
dict
(
image
=
4
)
# Values test_chat is parametrized over for max_model_len.
MAX_MODEL_LEN
=
[
8192
,
65536
]
# Reference logprob fixtures read by load_logprobs. The paths are relative,
# so open() resolves them against the current working directory.
# NOTE(review): presumably pytest is expected to run from the repo root.
FIXTURE_LOGPROBS_CHAT
=
"tests/models/fixtures/pixtral_chat.pickle"
FIXTURE_LOGPROBS_ENGINE
=
"tests/models/fixtures/pixtral_chat_engine.pickle"
def load_logprobs(filename: str) -> Any:
    """Return the object unpickled from the file at ``filename``.

    The file is opened in binary mode and closed before returning. Only
    use this on trusted fixture files: unpickling can execute arbitrary
    code.
    """
    with open(filename, mode='rb') as fixture_file:
        reference = pickle.load(fixture_file)
    return reference
@
pytest
.
mark
.
skip
(
...
...
@@ -16,49 +95,74 @@ MODELS = ["mistralai/Pixtral-12B-2409"]
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"max_model_len"
,
MAX_MODEL_LEN
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
def
test_chat
(
vllm_runner
,
example_prompts
,
max_model_len
:
int
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
)
->
None
:
image_urls
=
[
"https://picsum.photos/id/237/200/300"
,
"https://picsum.photos/seed/picsum/200/300"
]
expected
=
[
"The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression."
,
# noqa
"The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset."
# noqa
]
prompt
=
"Describe the image in one short sentence."
sampling_params
=
SamplingParams
(
max_tokens
=
512
,
temperature
=
0.0
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
tokenizer_mode
=
"mistral"
)
as
vllm_model
:
for
i
,
image_url
in
enumerate
(
image_urls
):
messages
=
[
{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
prompt
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
}]
},
]
EXPECTED_CHAT_LOGPROBS
=
load_logprobs
(
FIXTURE_LOGPROBS_CHAT
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
tokenizer_mode
=
"mistral"
,
enable_chunked_prefill
=
False
,
max_model_len
=
max_model_len
,
limit_mm_per_prompt
=
LIMIT_MM_PER_PROMPT
,
)
as
vllm_model
:
outputs
=
[]
for
msg
in
MSGS
:
output
=
vllm_model
.
model
.
chat
(
msg
,
sampling_params
=
SAMPLING_PARAMS
)
outputs
.
extend
(
output
)
logprobs
=
vllm_runner
.
_final_steps_generate_w_logprobs
(
outputs
)
check_logprobs_close
(
outputs_0_lst
=
logprobs
,
outputs_1_lst
=
EXPECTED_CHAT_LOGPROBS
,
name_0
=
"output"
,
name_1
=
"h100_ref"
)
@
pytest
.
mark
.
skip
(
reason
=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
def
test_model_engine
(
vllm_runner
,
model
:
str
,
dtype
:
str
)
->
None
:
EXPECTED_ENGINE_LOGPROBS
=
load_logprobs
(
FIXTURE_LOGPROBS_ENGINE
)
args
=
EngineArgs
(
model
=
model
,
tokenizer_mode
=
"mistral"
,
enable_chunked_prefill
=
False
,
limit_mm_per_prompt
=
LIMIT_MM_PER_PROMPT
,
dtype
=
dtype
,
)
engine
=
LLMEngine
.
from_engine_args
(
args
)
engine
.
add_request
(
uuid
.
uuid4
().
hex
,
ENGINE_INPUTS
[
0
],
SAMPLING_PARAMS
)
engine
.
add_request
(
uuid
.
uuid4
().
hex
,
ENGINE_INPUTS
[
1
],
SAMPLING_PARAMS
)
outputs
=
[]
count
=
0
while
True
:
out
=
engine
.
step
()
count
+=
1
for
request_output
in
out
:
if
request_output
.
finished
:
outputs
.
append
(
request_output
)
if
count
==
2
:
engine
.
add_request
(
uuid
.
uuid4
().
hex
,
ENGINE_INPUTS
[
2
],
SAMPLING_PARAMS
)
if
not
engine
.
has_unfinished_requests
():
break
outputs
=
vllm_model
.
model
.
chat
(
messages
,
sampling_params
=
sampling_params
)
assert
outputs
[
0
].
outputs
[
0
].
text
==
expected
[
i
]
logprobs
=
vllm_runner
.
_final_steps_generate_w_logprobs
(
outputs
)
check_logprobs_close
(
outputs_0_lst
=
logprobs
,
outputs_1_lst
=
EXPECTED_ENGINE_LOGPROBS
,
name_0
=
"output"
,
name_1
=
"h100_ref"
)
vllm/model_executor/models/pixtral.py
View file @
d31174a4
import
math
from
array
import
array
from
dataclasses
import
dataclass
,
fields
from
itertools
import
tee
...
...
@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.utils
import
merge_multimodal_embeddings
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
...
...
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
,
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
mm_encoder
=
tokenizer
.
instruct
.
mm_encoder
mm_config
=
ctx
.
model_config
.
multimodal_config
max_num_images_per_request
=
mm_config
.
limit_per_prompt
.
get
(
"image"
,
1
)
mm_encoder
=
tokenizer
.
mistral
.
instruct_tokenizer
.
mm_encoder
patch_size
=
mm_encoder
.
mm_config
.
image_patch_size
image_token_id
=
mm_encoder
.
special_ids
.
img
# approximate image size
size
=
int
(
math
.
sqrt
(
seq_len
)
*
mm_encoder
.
mm_config
.
im
age_patch_size
)
mm_config
=
ctx
.
model_config
.
multimodal_config
num_images
=
mm_config
.
l
im
it_per_prompt
.
get
(
"image"
,
1
)
# dummy size
size
=
256
image
=
Image
.
new
(
"RGB"
,
(
size
,
size
),
color
=
0
)
img_chunk
=
ImageChunk
(
image
=
image
)
tokens
=
mm_encoder
(
img_chunk
).
tokens
token_ids
=
max_num_images_per_request
*
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
tokens
)
image_feature_size
=
(
size
**
2
)
//
(
patch_size
**
2
)
num_image_tokens
=
image_feature_size
*
num_images
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
num_image_tokens
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
num_image_tokens
)
seq_data
=
SequenceData
(
token_ids
)
mm_data
=
{
"image"
:
max_
num_images
_per_request
*
[
image
]}
mm_data
=
{
"image"
:
num_images
*
[
image
]}
return
seq_data
,
mm_data
...
...
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
return
MultiModalInputs
({
"images"
:
images
})
def
merge_multimodal_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
image_features
:
Optional
[
List
[
torch
.
Tensor
]],
image_id
:
int
)
->
torch
.
Tensor
:
text_locations
=
input_ids
!=
image_id
image_locations
=
input_ids
==
image_id
seq_len
=
input_ids
.
shape
[
0
]
def
input_processor_for_pixtral
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
not
None
and
"image"
in
multi_modal_data
:
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
,
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
N_txt
=
text_locations
.
sum
().
item
()
_
,
D_txt
=
inputs_embeds
.
shape
N_img
,
D_img
=
image_features
.
shape
mm_encoder
=
tokenizer
.
mistral
.
instruct_tokenizer
.
mm_encoder
image_token_id
=
mm_encoder
.
special_ids
.
img
assert
(
D_txt
==
D_img
),
(
f
"Text features dim
{
D_txt
}
should be equal "
"to image features dim {D_img}"
)
assert
(
seq_len
==
N_txt
+
N_img
),
(
f
"seq_len
{
seq_len
}
should be equal to N_txt + N_img "
f
"
{
(
N_txt
,
N_img
,
image_locations
.
sum
().
item
())
}
"
)
if
image_token_id
not
in
llm_inputs
[
'prompt_token_ids'
]:
raise
ValueError
(
(
f
"You've passed
{
llm_inputs
=
}
without
{
image_token_id
=
}
"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
))
inputs_embeds
[
image_locations
,
:]
=
image_features
return
inputs_embeds
return
llm_inputs
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_pixtral
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_pixtral_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_pixtral
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_pixtral
)
class
PixtralForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
def
__init__
(
self
,
...
...
@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
return
None
if
isinstance
(
images
,
torch
.
Tensor
):
# always take last images
images
=
[
images
[
-
1
][
i
]
for
i
in
range
(
images
.
size
(
1
))]
# if passed as batch take all images
N
,
B
,
C
,
W
,
H
=
images
.
shape
images
=
images
.
reshape
(
N
*
B
,
C
,
W
,
H
)
images
=
[
images
[
i
]
for
i
in
range
(
images
.
size
(
0
))]
elif
isinstance
(
images
,
list
):
# always take last images
images
=
[
images
[
-
1
][
i
]
for
i
in
range
(
len
(
images
[
0
]))]
# if passed as list flatten lists of tensors
flatten_images
=
[]
for
imgs_per_req
in
images
:
imgs_per_req
=
[
imgs_per_req
[
i
]
for
i
in
range
(
imgs_per_req
.
size
(
0
))
]
if
isinstance
(
imgs_per_req
,
torch
.
Tensor
)
else
imgs_per_req
flatten_images
.
extend
(
imgs_per_req
)
images
=
flatten_images
return
images
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment