Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5340a2dc
Unverified
Commit
5340a2dc
authored
Aug 27, 2024
by
zifeitong
Committed by
GitHub
Aug 28, 2024
Browse files
[Model] Add multi-image input support for LLaVA-Next offline inference (#7230)
parent
345be0e2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
174 additions
and
52 deletions
+174
-52
tests/conftest.py
tests/conftest.py
+10
-11
tests/models/test_llava_next.py
tests/models/test_llava_next.py
+80
-13
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+34
-1
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+4
-4
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+12
-2
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+2
-2
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+32
-19
No files found.
tests/conftest.py
View file @
5340a2dc
...
@@ -41,6 +41,10 @@ _TEST_DIR = os.path.dirname(__file__)
...
@@ -41,6 +41,10 @@ _TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"example.txt"
)]
_TEST_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"example.txt"
)]
_LONG_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"summary.txt"
)]
_LONG_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"summary.txt"
)]
PromptImageInput
=
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]
PromptAudioInput
=
Union
[
List
[
Tuple
[
np
.
ndarray
,
int
]],
List
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]]
def
_read_prompts
(
filename
:
str
)
->
List
[
str
]:
def
_read_prompts
(
filename
:
str
)
->
List
[
str
]:
with
open
(
filename
,
"r"
)
as
f
:
with
open
(
filename
,
"r"
)
as
f
:
...
@@ -161,7 +165,7 @@ def example_encoder_decoder_prompts(
...
@@ -161,7 +165,7 @@ def example_encoder_decoder_prompts(
decoder prompt) tuple.
decoder prompt) tuple.
Returns:
Returns:
* Encoder prompt list
* Encoder prompt list
* Decoder prompt list (reverse of encoder prompt list)
* Decoder prompt list (reverse of encoder prompt list)
'''
'''
...
@@ -578,8 +582,7 @@ class VllmRunner:
...
@@ -578,8 +582,7 @@ class VllmRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
images
:
Optional
[
PromptImageInput
]
=
None
,
List
[
List
[
Image
.
Image
]]]]
=
None
,
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
if
images
is
not
None
:
if
images
is
not
None
:
assert
len
(
prompts
)
==
len
(
images
)
assert
len
(
prompts
)
==
len
(
images
)
...
@@ -623,10 +626,8 @@ class VllmRunner:
...
@@ -623,10 +626,8 @@ class VllmRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
images
:
Optional
[
PromptImageInput
]
=
None
,
List
[
List
[
Image
.
Image
]]]]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
Union
[
List
[
Tuple
[
np
.
ndarray
,
int
]],
List
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]]]
=
None
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
assert
sampling_params
.
logprobs
is
not
None
assert
sampling_params
.
logprobs
is
not
None
...
@@ -676,10 +677,8 @@ class VllmRunner:
...
@@ -676,10 +677,8 @@ class VllmRunner:
prompts
:
List
[
str
],
prompts
:
List
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
images
:
Optional
[
PromptImageInput
]
=
None
,
List
[
List
[
Image
.
Image
]]]]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
Union
[
List
[
Tuple
[
np
.
ndarray
,
int
]],
List
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
...
...
tests/models/test_llava_next.py
View file @
5340a2dc
...
@@ -6,24 +6,22 @@ from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
...
@@ -6,24 +6,22 @@ from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
from
..conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
,
_ImageAssets
from
..conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
,
_ImageAssets
)
from
.utils
import
check_logprobs_close
from
.utils
import
check_logprobs_close
pytestmark
=
pytest
.
mark
.
vlm
pytestmark
=
pytest
.
mark
.
vlm
_PREFACE
=
(
_LIMIT_IMAGE_PER_PROMPT
=
4
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's "
"questions."
)
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"stop_sign"
:
f
"
{
_PREFACE
}
USER:
<image>
\n
What's the content of the image?
ASSISTANT:
"
,
"[INST]
<image>
\n
What's the content of the image?
[/INST]
"
,
"cherry_blossom"
:
"cherry_blossom"
:
f
"
{
_PREFACE
}
USER:
<image>
\n
What is the season?
ASSISTANT:
"
,
"[INST]
<image>
\n
What is the season?
[/INST]
"
,
})
})
models
=
[
"llava-hf/llava-v1.6-
vicuna
-7b-hf"
]
models
=
[
"llava-hf/llava-v1.6-
mistral
-7b-hf"
]
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
...
@@ -114,19 +112,43 @@ def run_test(
...
@@ -114,19 +112,43 @@ def run_test(
else
:
else
:
raise
ValueError
(
"You must provide either `size_factors` or `sizes`"
)
raise
ValueError
(
"You must provide either `size_factors` or `sizes`"
)
_run_test
(
hf_runner
,
vllm_runner
,
inputs_per_image
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
)
def
_run_test
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
inputs
:
List
[
Tuple
[
List
[
str
],
PromptImageInput
]],
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
# max_model_len should be greater than image_feature_size
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
with
vllm_runner
(
model
,
dtype
=
dtype
,
dtype
=
dtype
,
max_model_len
=
40
96
,
max_model_len
=
102
40
,
tensor_parallel_size
=
tensor_parallel_size
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
"image"
:
_LIMIT_IMAGE_PER_PROMPT
})
as
vllm_model
:
vllm_outputs_per_image
=
[
vllm_outputs_per_image
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
max_tokens
,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
images
=
images
)
images
=
images
)
for
prompts
,
images
in
inputs
_per_image
for
prompts
,
images
in
inputs
]
]
with
hf_runner
(
model
,
dtype
=
dtype
,
with
hf_runner
(
model
,
dtype
=
dtype
,
...
@@ -136,7 +158,7 @@ def run_test(
...
@@ -136,7 +158,7 @@ def run_test(
max_tokens
,
max_tokens
,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
images
=
images
)
images
=
images
)
for
prompts
,
images
in
inputs
_per_image
for
prompts
,
images
in
inputs
]
]
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_image
,
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_image
,
...
@@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
...
@@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
The text output is sanitized to be able to compare with hf.
...
@@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
...
@@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
)
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models_multiple_image_inputs
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
stop_sign
=
image_assets
[
0
].
pil_image
cherry_blossom
=
image_assets
[
1
].
pil_image
inputs
=
[(
[
"[INST] <image><image>
\n
Describe 2 images. [/INST]"
,
"[INST] <image><image>
\n
Describe 2 images. [/INST]"
,
"[INST] <image><image><image><image>
\n
Describe 4 images. [/INST]"
,
"[INST] <image>
\n
What is the season? [/INST]"
],
[
[
stop_sign
,
cherry_blossom
],
# Images with different sizes and aspect-ratios
[
rescale_image_size
(
stop_sign
,
0.1
),
stop_sign
,
],
[
stop_sign
,
rescale_image_size
(
stop_sign
,
0.25
),
cherry_blossom
.
resize
((
183
,
488
)),
cherry_blossom
.
resize
((
488
,
183
))
],
cherry_blossom
,
])]
_run_test
(
hf_runner
,
vllm_runner
,
inputs
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
tests/multimodal/test_utils.py
View file @
5340a2dc
...
@@ -6,8 +6,10 @@ from typing import Dict, Tuple
...
@@ -6,8 +6,10 @@ from typing import Dict, Tuple
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
AutoConfig
,
AutoTokenizer
from
vllm.multimodal.utils
import
async_fetch_image
,
fetch_image
from
vllm.multimodal.utils
import
(
async_fetch_image
,
fetch_image
,
repeat_and_pad_placeholder_tokens
)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
TEST_IMAGE_URLS
=
[
...
@@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
...
@@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
data_image_async
=
await
async_fetch_image
(
data_url
)
data_image_async
=
await
async_fetch_image
(
data_url
)
assert
_image_equals
(
data_image_sync
,
data_image_async
)
assert
_image_equals
(
data_image_sync
,
data_image_async
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
def
test_repeat_and_pad_placeholder_tokens
(
model
):
config
=
AutoConfig
.
from_pretrained
(
model
)
image_token_id
=
config
.
image_token_index
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
test_cases
=
[
(
"<image>"
,
2
,
"<image><image>"
,
[
32000
,
32000
]),
(
"<image><image>"
,
2
,
"<image><image><image>"
,
[
32000
,
32000
,
32000
]),
(
"<image><image>"
,
[
3
,
2
],
"<image><image><image><image><image>"
,
[
32000
,
32000
,
32000
,
32000
,
32000
]),
(
"Image:<image>Image:<image>!"
,
[
3
,
2
],
"Image:<image><image><image>Image:<image><image>!"
,
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
]),
(
"<image>"
,
[
3
,
2
],
"<image><image><image>"
,
[
32000
,
32000
,
32000
]),
]
for
prompt
,
repeat_count
,
expected_prompt
,
expected_token_ids
in
test_cases
:
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
tokenizer
=
tokenizer
,
prompt
=
prompt
,
prompt_token_ids
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
),
placeholder_token_id
=
image_token_id
,
repeat_count
=
repeat_count
,
)
assert
new_prompt
==
expected_prompt
assert
new_token_ids
==
expected_token_ids
vllm/model_executor/models/clip.py
View file @
5340a2dc
"""Minimal implementation of CLIPVisionModel intended to be only used
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
within a vision language model."""
from
array
import
array
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -84,7 +84,7 @@ def input_processor_for_clip(
...
@@ -84,7 +84,7 @@ def input_processor_for_clip(
llm_inputs
:
LLMInputs
,
llm_inputs
:
LLMInputs
,
*
,
*
,
image_token_id
:
int
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
Union
[
int
,
List
[
int
]]
]
=
None
,
):
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
...
@@ -217,7 +217,7 @@ class CLIPEncoderLayer(nn.Module):
...
@@ -217,7 +217,7 @@ class CLIPEncoderLayer(nn.Module):
class
CLIPEncoder
(
nn
.
Module
):
class
CLIPEncoder
(
nn
.
Module
):
"""
"""
Transformer encoder consisting of `config.num_hidden_layers` self
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
Args:
...
...
vllm/model_executor/models/llava_next.py
View file @
5340a2dc
...
@@ -19,6 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -19,6 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_image_feature_size
,
dummy_seq_data_for_clip
,
get_clip_image_feature_size
,
...
@@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
input_height
=
height
,
input_height
=
height
,
input_width
=
width
,
input_width
=
width
,
)
)
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
[
get_llava_next_image_feature_size
(
hf_config
,
input_height
=
img
.
height
,
input_width
=
img
.
width
)
for
img
in
image_data
]
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
image_feature_size
=
image_data
.
shape
[
0
]
else
:
else
:
...
@@ -425,7 +433,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -425,7 +433,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
config
.
image_grid_pinpoints
,
self
.
config
.
image_grid_pinpoints
,
self
.
config
.
vision_config
.
image_size
,
self
.
config
.
vision_config
.
image_size
,
)
)
other_patch_embeds
=
other_patch_embeds
\
num_patches
=
num_patch_height
*
num_patch_width
# Image patches might be padded for batch processing
other_patch_embeds
=
other_patch_embeds
[:
num_patches
]
\
.
view
(
num_patch_height
,
num_patch_width
,
height
,
width
,
-
1
)
.
view
(
num_patch_height
,
num_patch_width
,
height
,
width
,
-
1
)
if
"unpad"
in
strategy
:
if
"unpad"
in
strategy
:
...
@@ -496,7 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -496,7 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
self
,
self
,
image_input
:
LlavaNextImageInputs
,
image_input
:
LlavaNextImageInputs
,
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
if
image_input
[
"type"
]
==
"image_embeds"
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
[
image_input
[
"data"
]]
return
[
image_input
[
"data"
]]
...
...
vllm/model_executor/models/siglip.py
View file @
5340a2dc
...
@@ -3,7 +3,7 @@ within a vision language model."""
...
@@ -3,7 +3,7 @@ within a vision language model."""
import
math
import
math
from
array
import
array
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
...
@@ -93,7 +93,7 @@ def input_processor_for_siglip(
...
@@ -93,7 +93,7 @@ def input_processor_for_siglip(
llm_inputs
:
LLMInputs
,
llm_inputs
:
LLMInputs
,
*
,
*
,
image_token_id
:
int
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
Union
[
int
,
List
[
int
]]
]
=
None
,
):
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
...
...
vllm/multimodal/utils.py
View file @
5340a2dc
...
@@ -189,10 +189,13 @@ def repeat_and_pad_placeholder_tokens(
...
@@ -189,10 +189,13 @@ def repeat_and_pad_placeholder_tokens(
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
*
,
*
,
placeholder_token_id
:
int
,
placeholder_token_id
:
int
,
repeat_count
:
int
=
1
,
repeat_count
:
Union
[
int
,
List
[
int
]]
,
pad_token_left
:
Optional
[
int
]
=
None
,
pad_token_left
:
Optional
[
int
]
=
None
,
pad_token_right
:
Optional
[
int
]
=
None
,
pad_token_right
:
Optional
[
int
]
=
None
,
)
->
Tuple
[
Optional
[
str
],
List
[
int
]]:
)
->
Tuple
[
Optional
[
str
],
List
[
int
]]:
if
isinstance
(
repeat_count
,
int
):
repeat_count
=
[
repeat_count
]
if
prompt
is
None
:
if
prompt
is
None
:
new_prompt
=
None
new_prompt
=
None
else
:
else
:
...
@@ -201,13 +204,6 @@ def repeat_and_pad_placeholder_tokens(
...
@@ -201,13 +204,6 @@ def repeat_and_pad_placeholder_tokens(
tokenizer
.
decode
(
pad_token_left
))
tokenizer
.
decode
(
pad_token_left
))
pad_token_str_right
=
(
None
if
pad_token_right
is
None
else
pad_token_str_right
=
(
None
if
pad_token_right
is
None
else
tokenizer
.
decode
(
pad_token_right
))
tokenizer
.
decode
(
pad_token_right
))
replacement_str
=
""
.
join
(
repeat_and_pad_token
(
placeholder_token_str
,
repeat_count
=
repeat_count
,
pad_token_left
=
pad_token_str_left
,
pad_token_right
=
pad_token_str_right
,
))
placeholder_token_count
=
prompt
.
count
(
placeholder_token_str
)
placeholder_token_count
=
prompt
.
count
(
placeholder_token_str
)
# This is an arbitrary number to distinguish between the two cases
# This is an arbitrary number to distinguish between the two cases
...
@@ -216,28 +212,45 @@ def repeat_and_pad_placeholder_tokens(
...
@@ -216,28 +212,45 @@ def repeat_and_pad_placeholder_tokens(
"Please follow the prompt format that is "
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"documented on HuggingFace which does not involve "
"repeating %s tokens."
,
placeholder_token_str
)
"repeating %s tokens."
,
placeholder_token_str
)
elif
placeholder_token_count
>
1
:
if
placeholder_token_count
<
len
(
repeat_count
):
logger
.
warning
(
"Multiple multi-modal input is not supported yet, "
logger
.
warning
(
"so any extra placeholder tokens will be treated "
"The number of multi-modal placeholder tokens in the prompt "
"as plain text."
)
"is less than the number of multi-modal inputs. Extra "
"placeholder tokens will be treated as plain text"
)
# The image tokens are removed to be consistent with HuggingFace
repeat_count
=
repeat_count
[:
placeholder_token_count
]
new_prompt
=
prompt
.
replace
(
placeholder_token_str
,
replacement_str
,
1
)
prompt_parts
=
prompt
.
split
(
placeholder_token_str
,
maxsplit
=
len
(
repeat_count
))
new_prompt
=
""
for
i
,
repeat_count_item
in
enumerate
(
repeat_count
):
replacement_str
=
""
.
join
(
repeat_and_pad_token
(
placeholder_token_str
,
repeat_count
=
repeat_count_item
,
pad_token_left
=
pad_token_str_left
,
pad_token_right
=
pad_token_str_right
,
))
# The image tokens are removed to be consistent with HuggingFace
new_prompt
+=
prompt_parts
[
i
]
+
replacement_str
new_prompt
+=
prompt_parts
[
-
1
]
new_token_ids
:
List
[
int
]
=
[]
new_token_ids
:
List
[
int
]
=
[]
placeholder_token_idx
=
0
for
i
,
token
in
enumerate
(
prompt_token_ids
):
for
i
,
token
in
enumerate
(
prompt_token_ids
):
if
token
==
placeholder_token_id
:
if
token
==
placeholder_token_id
:
replacement_ids
=
repeat_and_pad_token
(
replacement_ids
=
repeat_and_pad_token
(
placeholder_token_id
,
placeholder_token_id
,
repeat_count
=
repeat_count
,
repeat_count
=
repeat_count
[
placeholder_token_idx
]
,
pad_token_left
=
pad_token_left
,
pad_token_left
=
pad_token_left
,
pad_token_right
=
pad_token_right
,
pad_token_right
=
pad_token_right
,
)
)
new_token_ids
.
extend
(
replacement_ids
)
new_token_ids
.
extend
(
replacement_ids
)
placeholder_token_idx
+=
1
# No need to further scan the list since we only replace once
# No need to further scan the list since we replaced all tokens
new_token_ids
.
extend
(
prompt_token_ids
[
i
+
1
:])
if
placeholder_token_idx
>=
len
(
repeat_count
):
break
new_token_ids
.
extend
(
prompt_token_ids
[
i
+
1
:])
break
else
:
else
:
new_token_ids
.
append
(
token
)
new_token_ids
.
append
(
token
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment