Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a0f7d53b
Unverified
Commit
a0f7d53b
authored
Dec 19, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 19, 2024
Browse files
[Bugfix] Cleanup Pixtral HF code (#11333)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
5aef4980
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
141 deletions
+14
-141
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+14
-141
No files found.
vllm/model_executor/models/pixtral.py
View file @
a0f7d53b
...
@@ -10,12 +10,12 @@ from mistral_common.protocol.instruct.messages import ImageChunk
...
@@ -10,12 +10,12 @@ from mistral_common.protocol.instruct.messages import ImageChunk
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
PixtralVisionConfig
from
transformers
import
PixtralVisionConfig
from
transformers.models.pixtral.image_processing_pixtral
import
(
from
transformers.models.pixtral.image_processing_pixtral
import
(
_num_image_tokens
)
_num_image_tokens
as
_get_pixtral_hf_num_image_tokens
)
from
transformers.models.pixtral.modeling_pixtral
import
(
from
transformers.models.pixtral.modeling_pixtral
import
(
PixtralRotaryEmbedding
,
apply_rotary_pos_emb
,
position_ids_in_meshgrid
)
PixtralRotaryEmbedding
,
apply_rotary_pos_emb
,
position_ids_in_meshgrid
)
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
InputContext
,
token_inputs
)
...
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.utils
import
merge_multimodal_embeddings
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
...
@@ -35,11 +34,10 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -35,11 +34,10 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges
,
consecutive_placeholder_ranges
,
resolve_visual_encoder_outputs
)
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.utils
import
is_list_of
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
init_vllm_registered_model
,
maybe_prefix
from
.utils
import
(
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
try
:
try
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
@@ -699,37 +697,14 @@ def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
...
@@ -699,37 +697,14 @@ def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
return
grid_length
*
grid_length
return
grid_length
*
grid_length
def
get_max_pixtral_hf_image_feature_size
(
hf_config
:
PixtralVisionConfig
)
->
int
:
return
get_pixtral_hf_num_patches
(
image_size
=
hf_config
.
image_size
,
patch_size
=
hf_config
.
patch_size
)
def
get_max_pixtral_hf_image_tokens
(
hf_config
:
PixtralVisionConfig
)
->
int
:
def
get_max_pixtral_hf_image_tokens
(
hf_config
:
PixtralVisionConfig
)
->
int
:
return
get_max_pixtral_hf_image_feature_size
(
hf_config
)
grid_length
=
get_pixtral_hf_patch_grid_length
(
image_size
=
hf_config
.
image_size
,
patch_size
=
hf_config
.
patch_size
,
)
# Consider the image_break_token
def
dummy_seq_data_for_pixtral_hf
(
return
(
grid_length
+
1
)
*
grid_length
hf_config
:
PixtralVisionConfig
,
seq_len
:
int
,
num_images
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
mm_key
:
str
=
"image"
):
if
image_feature_size_override
is
None
:
image_feature_size
=
get_max_pixtral_hf_image_feature_size
(
hf_config
)
else
:
image_feature_size
=
image_feature_size_override
return
SequenceData
.
from_prompt_token_counts
(
(
image_token_id
,
image_feature_size
*
num_images
),
(
0
,
seq_len
-
image_feature_size
*
num_images
),
),
{
mm_key
:
consecutive_placeholder_ranges
(
num_items
=
num_images
,
item_size
=
image_feature_size
)
}
def
dummy_image_for_pixtral_hf
(
def
dummy_image_for_pixtral_hf
(
...
@@ -763,116 +738,14 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
...
@@ -763,116 +738,14 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
image_width
=
int
(
numpy
.
ceil
(
image_width
/
ratio
))
image_width
=
int
(
numpy
.
ceil
(
image_width
/
ratio
))
image_height
=
int
(
numpy
.
ceil
(
image_height
/
ratio
))
image_height
=
int
(
numpy
.
ceil
(
image_height
/
ratio
))
num_height_tokens
,
num_width_tokens
=
_num_image_tokens
(
num_height_tokens
,
num_width_tokens
=
_get_pixtral_hf_num_image_tokens
(
(
image_height
,
image_width
),
(
patch_height
,
patch_width
))
(
image_height
,
image_width
),
(
patch_height
,
patch_width
),
)
return
num_width_tokens
,
num_height_tokens
return
num_width_tokens
,
num_height_tokens
def
input_processor_for_pixtral_hf
(
model_config
:
ModelConfig
,
hf_config
:
PixtralVisionConfig
,
inputs
:
DecoderOnlyInputs
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
)
->
DecoderOnlyInputs
:
assert
image_feature_size_override
is
None
,
(
"image_feature_size_override is not supported for Pixtral"
)
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
inputs
processor
=
cached_get_processor
(
model_config
.
model
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
image_data
=
[
image_data
]
elif
not
is_list_of
(
image_data
,
Image
.
Image
):
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
new_prompt
=
inputs
.
get
(
"prompt"
)
new_token_ids
=
inputs
[
"prompt_token_ids"
]
image_token
=
processor
.
image_token
image_break_token
=
processor
.
image_break_token
image_end_token
=
processor
.
image_end_token
# Update new_prompt if present
if
new_prompt
:
parts
=
new_prompt
.
split
(
image_token
)
assert
len
(
parts
)
-
1
==
len
(
image_data
)
new_parts
=
[
parts
[
0
]]
# Start with the part before any image tokens
for
image
,
next_part
in
zip
(
image_data
,
parts
[
1
:]):
w
,
h
=
image
.
size
(
num_width_tokens
,
num_height_tokens
)
=
get_pixtral_hf_image_feature_size
(
hf_config
,
image_width
=
w
,
image_height
=
h
)
replace_tokens
=
[
image_token
]
*
num_width_tokens
+
[
image_break_token
]
replace_tokens
=
replace_tokens
*
num_height_tokens
replace_tokens
[
-
1
]
=
image_end_token
new_parts
.
append
(
""
.
join
(
replace_tokens
))
new_parts
.
append
(
next_part
)
new_prompt
=
""
.
join
(
new_parts
)
# Update new_token_ids
convert_tokens_to_ids
=
processor
.
tokenizer
.
convert_tokens_to_ids
image_token_id
=
convert_tokens_to_ids
(
image_token
)
image_break_id
=
convert_tokens_to_ids
(
image_break_token
)
image_end_id
=
convert_tokens_to_ids
(
image_end_token
)
placeholder_token_id
=
-
999
# Find all image token indices at once
placeholder_indices
=
[
idx
for
idx
,
token_id
in
enumerate
(
new_token_ids
)
if
token_id
==
image_token_id
]
assert
len
(
placeholder_indices
)
==
len
(
image_data
)
replace_tokens_list
=
[]
for
placeholder_idx
,
image
in
zip
(
placeholder_indices
,
image_data
):
new_token_ids
[
placeholder_idx
]
=
placeholder_token_id
w
,
h
=
image
.
size
(
num_width_tokens
,
num_height_tokens
)
=
get_pixtral_hf_image_feature_size
(
hf_config
,
image_width
=
w
,
image_height
=
h
)
replace_tokens
=
[
image_token_id
]
*
num_width_tokens
+
[
image_break_id
]
replace_tokens
=
replace_tokens
*
num_height_tokens
replace_tokens
[
-
1
]
=
image_end_id
replace_tokens_list
.
append
(
replace_tokens
)
reverse_offsets
:
List
[
int
]
=
[]
# Backward iteration for replacement without affecting known indices
for
placeholder_idx
,
replace_tokens
in
zip
(
reversed
(
placeholder_indices
),
reversed
(
replace_tokens_list
)):
reverse_offsets
.
append
(
len
(
new_token_ids
)
-
placeholder_idx
+
len
(
replace_tokens
))
new_token_ids
[
placeholder_idx
:
placeholder_idx
+
1
]
=
replace_tokens
placeholder_ranges
:
List
[
PlaceholderRange
]
=
[]
for
reverse_offset
,
replace_tokens
in
zip
(
reversed
(
reverse_offsets
),
replace_tokens_list
):
placeholder_ranges
.
append
(
PlaceholderRange
(
offset
=
len
(
new_token_ids
)
-
reverse_offset
,
length
=
len
(
replace_tokens
),
))
# NOTE: Create a defensive copy of the original inputs
return
token_inputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
,
multi_modal_placeholders
=
{
"image"
:
placeholder_ranges
})
class
PixtralHFMLP
(
nn
.
Module
):
class
PixtralHFMLP
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment