Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6b29d6fe
Unverified
Commit
6b29d6fe
authored
Jun 10, 2024
by
Cyrus Leung
Committed by
GitHub
Jun 10, 2024
Browse files
[Model] Initial support for LLaVA-NeXT (#4199)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
0bfa1c4f
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
640 additions
and
18 deletions
+640
-18
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+5
-1
tests/models/test_llava.py
tests/models/test_llava.py
+0
-2
tests/models/test_llava_next.py
tests/models/test_llava_next.py
+123
-0
tests/multimodal/test_processor.py
tests/multimodal/test_processor.py
+55
-7
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+2
-0
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+10
-8
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+445
-0
No files found.
docs/source/models/supported_models.rst
View file @
6b29d6fe
...
...
@@ -89,7 +89,11 @@ Alongside each architecture, we include some popular models that use it.
- ✅︎
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc.
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
-
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
...
...
tests/models/test_llava.py
View file @
6b29d6fe
...
...
@@ -39,8 +39,6 @@ def iter_llava_configs(model_name: str):
model_and_vl_config
=
[
*
iter_llava_configs
(
"llava-hf/llava-1.5-7b-hf"
),
# Not enough memory
# *iter_llava_configs("llava-hf/llava-1.5-13b-hf"),
]
...
...
tests/models/test_llava_next.py
0 → 100644
View file @
6b29d6fe
from
typing
import
List
,
Tuple
import
pytest
from
transformers
import
AutoTokenizer
from
vllm.config
import
VisionLanguageConfig
from
..conftest
import
IMAGE_FILES
pytestmark
=
pytest
.
mark
.
llava
_PREFACE
=
(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's "
"questions."
)
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS
=
[
f
"
{
_PREFACE
}
<image>
\n
USER: What's the content of the image? ASSISTANT:"
,
f
"
{
_PREFACE
}
<image>
\n
USER: What is the season? ASSISTANT:"
,
]
assert
len
(
HF_IMAGE_PROMPTS
)
==
len
(
IMAGE_FILES
)
def iter_llava_next_configs(model_name: str):
    """Yield ``(model_name, VisionLanguageConfig)`` pairs for the test matrix.

    One pair is produced per supported ``(height, width)`` image shape. Each
    config uses pixel-value inputs, image token id 32000, and the model
    itself as the image processor.
    """
    pixel_input_type = VisionLanguageConfig.ImageInputType.PIXEL_VALUES

    # ((height, width), image_feature_size) in the order they are yielded.
    feature_size_by_hw = (
        ((336, 336), 1176),
        ((672, 672), 2928),
        ((1344, 336), 1944),
        ((336, 1344), 1890),
    )

    for (height, width), feature_size in feature_size_by_hw:
        vl_config = VisionLanguageConfig(
            image_input_type=pixel_input_type,
            image_feature_size=feature_size,
            image_token_id=32000,
            image_input_shape=(1, 3, height, width),
            image_processor=model_name,
            image_processor_revision=None,
        )
        yield (model_name, vl_config)
model_and_vl_config
=
[
*
iter_llava_next_configs
(
"llava-hf/llava-v1.6-vicuna-7b-hf"
),
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vllm output to be comparable with hf output.

    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... by keeping only the first
    token of every run of consecutive image tokens.

    In `output_str`, each run of `image_feature_size` consecutive image-token
    strings is replaced with a single space.

    Args:
        vllm_output: ``(token_ids, text)`` as produced by the vLLM runner.
        vlm_config: Supplies `image_token_id` and `image_feature_size`.
        model_id: Name of the model whose tokenizer decodes the image token.

    Returns:
        ``(hf_input_ids, hf_output_str)`` in the HuggingFace-comparable form.
    """
    input_ids, output_str = vllm_output
    image_token_id = vlm_config.image_token_id

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(image_token_id)

    # Keep a token unless it is an image token directly preceded by another
    # image token. The first token has no predecessor; without the explicit
    # `idx == 0` check, `input_ids[idx - 1]` would wrap around to the last
    # element and could wrongly drop a leading image token.
    hf_input_ids = [
        input_id for idx, input_id in enumerate(input_ids)
        if (input_id != image_token_id or idx == 0
            or input_ids[idx - 1] != image_token_id)
    ]
    hf_output_str = output_str.replace(
        image_token_str * vlm_config.image_feature_size, " ")

    return hf_input_ids, hf_output_str
@pytest.mark.xfail(
    reason="Inconsistent image processor being used due to lack "
    "of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
                model_and_config, dtype: str, max_tokens: int) -> None:
    """Check that greedy generation matches between HuggingFace and vLLM.

    The image fixtures live under tests/images. The HuggingFace runner is
    given PIL images; the vLLM runner is given MultiModalData objects plus
    the matching vision language config. The vLLM prompt repeats the
    `<image>` placeholder `image_feature_size` times, and the vLLM output is
    sanitized with `vllm_to_hf_output` before it is compared.
    """
    model_id, vlm_config = model_and_config

    # vLLM expects one placeholder per image feature position.
    expanded_placeholder = "<image>" * vlm_config.image_feature_size
    vllm_image_prompts = [
        prompt.replace("<image>", expanded_placeholder)
        for prompt in HF_IMAGE_PROMPTS
    ]

    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                              max_tokens,
                                              images=hf_images)

    vllm_kwargs = vlm_config.as_cli_args_dict()
    with vllm_runner(
            model_id,
            dtype=dtype,
            # should be greater than image_feature_size
            max_model_len=4096,
            enforce_eager=True,
            **vllm_kwargs,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                                  max_tokens,
                                                  images=vllm_images)

    # Index both result lists by prompt position so that a missing result
    # surfaces as an error instead of being silently skipped.
    for i in range(len(HF_IMAGE_PROMPTS)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        sanitized = vllm_to_hf_output(vllm_outputs[i], vlm_config, model_id)
        vllm_output_ids, vllm_output_str = sanitized

        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
tests/multimodal/test_processor.py
View file @
6b29d6fe
import
numpy
as
np
import
pytest
from
transformers
import
CLIPImageProcessor
from
transformers
import
CLIPImageProcessor
,
LlavaNextImageProcessor
from
vllm.config
import
ModelConfig
,
VisionLanguageConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -12,7 +12,7 @@ from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
,
"float"
])
def
test_clip_image_processor
(
hf_images
,
dtype
):
MODEL_NAME
=
"llava-hf/llava-1.5-7b-hf"
IMAGE_HEIGHT
=
IMAGE_WIDTH
=
33
IMAGE_HEIGHT
=
IMAGE_WIDTH
=
560
hf_processor
=
CLIPImageProcessor
.
from_pretrained
(
MODEL_NAME
)
assert
isinstance
(
hf_processor
,
CLIPImageProcessor
)
...
...
@@ -55,10 +55,61 @@ def test_clip_image_processor(hf_images, dtype):
assert
np
.
allclose
(
hf_arr
,
vllm_arr
),
f
"Failed for key=
{
key
}
"
@pytest.mark.xfail(
    reason="Inconsistent image processor being used due to lack "
    "of support for dynamic image token replacement")
@pytest.mark.parametrize("dtype", ["half", "float"])
def test_llava_next_image_processor(hf_images, dtype):
    """Compare vLLM's image preprocessing with LlavaNextImageProcessor.

    Every fixture image is run through both the HuggingFace processor and
    `MULTIMODAL_REGISTRY.process_input`; the resulting dictionaries must
    have the same keys, and each tensor must match in shape and value.
    """
    MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
    IMAGE_HEIGHT = IMAGE_WIDTH = 560

    hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
    assert isinstance(hf_processor, LlavaNextImageProcessor)

    model_config = ModelConfig(
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype=dtype,
        revision=None,
    )
    vlm_config = VisionLanguageConfig(
        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
        image_token_id=64000,
        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
        image_feature_size=2928,
        image_processor=MODEL_NAME,
        image_processor_revision=None,
    )

    for image in hf_images:
        hf_batch = hf_processor.preprocess(image, return_tensors="pt")
        hf_result = hf_batch.to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])

        vllm_result = MULTIMODAL_REGISTRY.process_input(
            ImagePixelData(image),
            model_config=model_config,
            vlm_config=vlm_config,
        )

        assert hf_result.keys() == vllm_result.keys()

        for key, hf_tensor in hf_result.items():
            hf_arr: np.ndarray = hf_tensor.numpy()
            vllm_arr: np.ndarray = vllm_result[key].numpy()
            failure_message = f"Failed for key={key}"

            assert hf_arr.shape == vllm_arr.shape, failure_message
            assert np.allclose(hf_arr, vllm_arr), failure_message
@
pytest
.
mark
.
xfail
(
reason
=
"Example image pixels were not processed using HuggingFace"
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
def
test_image_pixel_types
(
hf_images
,
vllm_image_tensors
,
dtype
):
MODEL_NAME
=
"llava-hf/llava-1.5-7b-hf"
IMAGE_HEIGHT
=
IMAGE_WIDTH
=
33
IMAGE_HEIGHT
=
IMAGE_WIDTH
=
560
model_config
=
ModelConfig
(
model
=
MODEL_NAME
,
...
...
@@ -95,7 +146,4 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
tensor_arr
:
np
.
ndarray
=
tensor_result
[
key
].
numpy
()
assert
image_arr
.
shape
==
tensor_arr
.
shape
,
f
"Failed for key=
{
key
}
"
# The examples in PR#3042 have slightly different preprocessing from
# HuggingFace's LlavaProcessor, causing the test to fail.
# assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
assert
np
.
allclose
(
image_arr
,
tensor_arr
),
f
"Failed for key=
{
key
}
"
vllm/model_executor/models/__init__.py
View file @
6b29d6fe
...
...
@@ -33,6 +33,8 @@ _GENERATION_MODELS = {
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# For decapoda-research/llama-*
"LLaMAForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"MistralForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
...
...
vllm/model_executor/models/llava.py
View file @
6b29d6fe
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
import
torch
from
torch
import
nn
import
torch.nn
as
nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from
transformers
import
CLIPVisionModel
,
LlavaConfig
...
...
@@ -51,7 +51,7 @@ class LlavaMultiModalProjector(nn.Module):
return
hidden_states
def
_
merge_vision_embeddings
(
input_ids
:
torch
.
Tensor
,
def
merge_vision_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
vision_embeddings
:
torch
.
Tensor
,
image_token_id
:
int
)
->
torch
.
Tensor
:
...
...
@@ -151,7 +151,8 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
return
None
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values"
)
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
LlavaImagePixelInputs
(
type
=
"pixel_values"
,
...
...
@@ -166,7 +167,8 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
return
None
if
not
isinstance
(
image_features
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image features"
)
raise
ValueError
(
"Incorrect type of image features. "
f
"Got type:
{
type
(
image_features
)
}
"
)
return
LlavaImageFeatureInputs
(
type
=
"image_features"
,
...
...
@@ -268,7 +270,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
_
merge_vision_embeddings
(
inputs_embeds
=
merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
vision_language_config
.
image_token_id
)
...
...
vllm/model_executor/models/llava_next.py
0 → 100644
View file @
6b29d6fe
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
from
PIL
import
Image
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from
transformers
import
CLIPVisionModel
,
LlavaNextConfig
from
transformers.models.llava_next.modeling_llava_next
import
(
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VisionLanguageConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalData
from
vllm.multimodal.image
import
ImagePixelData
,
get_dummy_image_data
from
vllm.sequence
import
SamplerOutput
,
SequenceData
from
.llava
import
LlavaMultiModalProjector
,
merge_vision_embeddings
from
.vlm_base
import
VisionLanguageModelBase
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
class LlavaNextImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values.

    This is the variant built by `_parse_and_validate_image_input` when the
    model is configured for PIXEL_VALUES.
    """
    # Tag used by `_process_image_input` to tell the input variants apart.
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""

    # Optional in the type; `_process_image_input` falls back to the vision
    # tower's `image_size` for both entries when this key is absent.
    image_sizes: NotRequired[torch.Tensor]
    """Shape: (batch_size, 2)"""
class LlavaNextImageFeatureInputs(TypedDict):
    """Image input supplied as precomputed image features.

    NOTE(review): the model constructor in this file raises TypeError for
    any input type other than PIXEL_VALUES, so this variant is declared for
    typing purposes but does not appear to be produced by the visible code.
    """
    # Tag used by `_process_image_input` to tell the input variants apart.
    type: Literal["image_features"]
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""

    # Optional per-image sizes, one row of two values per batch entry.
    image_sizes: NotRequired[torch.Tensor]
    """Shape: (batch_size, 2)"""
LlavaNextImageInputs
=
Union
[
LlavaNextImagePixelInputs
,
LlavaNextImageFeatureInputs
]
def _get_dummy_image_data(
    seq_len: int,
    model_config: ModelConfig,
    vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
    """Build dummy sequence and image data for LLaVA-NeXT.

    The sequence data always comes from the generic `get_dummy_image_data`.
    For PIXEL_VALUES configs, the generic multi-modal data is replaced by a
    blank PIL image of the configured size, wrapped in `ImagePixelData`.
    For any other input type the generic multi-modal data is returned as is.
    """
    seq_data, generic_mm_data = get_dummy_image_data(seq_len, model_config,
                                                     vlm_config)

    pixel_values_type = VisionLanguageConfig.ImageInputType.PIXEL_VALUES
    if not (vlm_config.image_input_type == pixel_values_type):
        return seq_data, generic_mm_data

    _, channels, height, width = vlm_config.image_input_shape

    # Only 1-channel (grayscale) and 3-channel (RGB) shapes are supported;
    # any other channel count raises KeyError here.
    pil_mode_by_channels = {1: "L", 3: "RGB"}
    pil_mode = pil_mode_by_channels[channels]

    # PIL takes the size as (width, height).
    blank_image = Image.new(pil_mode, (width, height), color=0)
    return seq_data, ImagePixelData(blank_image)
def _image_pixel_processor(
    data: ImagePixelData,
    model_config: ModelConfig,
    vlm_config: VisionLanguageConfig,
) -> Dict[str, torch.Tensor]:
    """Turn `ImagePixelData` into the keyword tensors consumed by the model.

    Tensor images are cast to the model dtype and returned directly together
    with a synthesized `image_sizes` tensor. Other images are resized to the
    configured `image_input_shape` when needed and then handed to the default
    image-pixel input processor of the multi-modal registry.

    The caller's `data` object is never modified: when a resize is required
    it is applied to a shallow copy.

    Args:
        data: The image to process; `data.image` is either a tensor or an
            object exposing `width`, `height` and `resize`.
        model_config: Supplies the target dtype for tensor images.
        vlm_config: Supplies `image_input_shape` as (_, _, height, width).

    Returns:
        For tensor images, a dict with "pixel_values" and "image_sizes";
        otherwise whatever the default input processor returns.
    """
    import copy

    image = data.image

    if isinstance(image, torch.Tensor):
        pixel_values = image.to(model_config.dtype)
        batch_size, _, _, h, w = pixel_values.shape
        # NOTE(review): sizes are emitted as (width, height) taken from the
        # tensor's own last two dimensions -- confirm this matches the
        # ordering used for `image_sizes` on the non-tensor path.
        image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])

        return {"pixel_values": pixel_values, "image_sizes": image_sizes}

    # Temporary patch before dynamic number of image tokens is supported
    _, _, h, w = vlm_config.image_input_shape
    if (w, h) != (image.width, image.height):
        logger.warning(
            "Dynamic image shape is currently not supported. "
            "Resizing input image to (%d, %d).", w, h)

        # Resize a shallow copy so the caller's object keeps its original
        # image instead of being silently overwritten.
        data = copy.copy(data)
        data.image = image.resize((w, h))

    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
        ._default_input_processor(data, model_config, vlm_config)
@
MULTIMODAL_REGISTRY
.
register_image_pixel_input
(
_image_pixel_processor
)
@
MULTIMODAL_REGISTRY
.
register_dummy_data
(
_get_dummy_image_data
)
class
LlavaNextForConditionalGeneration
(
VisionLanguageModelBase
):
"""
Args to `forward()`:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: For PIXEL_VALUES, expects a batch with shape
[1, num_patches, 3, 336, 336].
image_features: For IMAGE_FEATURES, expects a batch with shape
[1, num_patches, 1176, 1024].
"""
def __init__(self,
             config: LlavaNextConfig,
             vision_language_config: VisionLanguageConfig,
             cache_config: Optional[CacheConfig] = None,
             quant_config: Optional[QuantizationConfig] = None) -> None:
    """Build the vision tower, projector, language model and output heads.

    Args:
        config: HuggingFace LLaVA-NeXT config; its `vision_config` and
            `text_config` parameterize the sub-models.
        vision_language_config: vLLM vision-language settings; only the
            PIXEL_VALUES image input type is accepted.
        cache_config: Passed through to the language model.
        quant_config: Stored on the instance and passed through to the
            language model.

    Raises:
        TypeError: If the configured image input type is not PIXEL_VALUES.
    """
    super().__init__(vision_language_config)

    # Update the type annotation from that of its superclass
    self.config = config

    if self.vision_language_config.image_input_type == (
            VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
        self.vision_tower = CLIPVisionModel(config.vision_config)
    else:
        raise TypeError("Image features are not supported by LLaVA-NeXT")

    # Maps vision hidden states into the text model's hidden size.
    self.multi_modal_projector = LlavaMultiModalProjector(
        vision_hidden_size=config.vision_config.hidden_size,
        text_hidden_size=config.text_config.hidden_size,
        projector_hidden_act=config.projector_hidden_act)

    self.quant_config = quant_config
    self.language_model = LlamaModel(config.text_config, cache_config,
                                     quant_config)
    self.unpadded_vocab_size = config.text_config.vocab_size
    self.lm_head = ParallelLMHead(
        self.unpadded_vocab_size,
        config.text_config.hidden_size,
        org_num_embeddings=self.language_model.org_vocab_size)
    # Falls back to 1.0 when the config does not define `logit_scale`.
    logit_scale = getattr(config, "logit_scale", 1.0)
    self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                            config.vocab_size, logit_scale)
    self.sampler = Sampler()

    # Created uninitialized (`torch.empty`); it is a registered parameter,
    # so `load_weights` can fill it from the checkpoint by name.
    self.image_newline = nn.Parameter(
        torch.empty(config.text_config.hidden_size))
def
_validate_image_pixels
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
_
,
num_channels
,
_
,
_
=
self
.
vision_language_config
.
image_input_shape
# Note that this is different from that of vLLM vision_language_config
# since the image is resized by the HuggingFace preprocessor
height
=
width
=
self
.
config
.
vision_config
.
image_size
if
list
(
data
.
shape
[
2
:])
!=
[
num_channels
,
height
,
width
]:
raise
ValueError
(
f
"The expected image tensor shape is batch dimension plus "
f
"num_patches plus
{
[
num_channels
,
height
,
width
]
}
. "
f
"You supplied
{
data
.
shape
}
. "
f
"If you are using vLLM's entrypoint, make sure your "
f
"supplied image input is consistent with "
f
"image_input_shape in engine args."
)
return
data
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
list
(
data
.
shape
[
1
:])
!=
[
2
]:
raise
ValueError
(
f
"The expected image sizes shape is batch dimension plus "
f
"
{
[
2
]
}
. You supplied
{
data
.
shape
}
."
)
return
data
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
    """Extract and validate the image-related keyword arguments.

    Reads "pixel_values", "image_sizes" and "image_features" from `kwargs`.

    Returns:
        A `LlavaNextImagePixelInputs` when the model expects pixel values
        and they were supplied; None when no pixel values were supplied.

    Raises:
        ValueError: If image features are supplied to a pixel-value model,
            if pixel values or image sizes are not tensors, or if either
            tensor fails shape validation.
    """
    pixel_values = kwargs.get("pixel_values")
    image_sizes = kwargs.get("image_sizes")
    image_features = kwargs.get("image_features")

    ImageInputType = VisionLanguageConfig.ImageInputType
    expected_input_type = self.vision_language_config.image_input_type

    if expected_input_type != ImageInputType.PIXEL_VALUES:
        # The constructor rejects every other input type, so getting here
        # with IMAGE_FEATURES means initialization-time validation failed.
        assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
            "Failed to validate this at initialization time")
        return None

    if image_features is not None:
        raise ValueError("Expected pixel values but got image features")

    if pixel_values is None:
        return None

    if not isinstance(pixel_values, torch.Tensor):
        raise ValueError("Incorrect type of pixel values. "
                         f"Got type: {type(pixel_values)}")

    if not isinstance(image_sizes, torch.Tensor):
        raise ValueError("Incorrect type of image sizes. "
                         f"Got type: {type(image_sizes)}")

    # Validate pixels first, then sizes, so that the first error reported
    # is the one about the pixel tensor.
    validated_pixels = self._validate_image_pixels(pixel_values)
    validated_sizes = self._validate_image_sizes(image_sizes)

    return LlavaNextImagePixelInputs(
        type="pixel_values",
        data=validated_pixels,
        image_sizes=validated_sizes,
    )
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                  patch_embeddings: torch.Tensor, *,
                                  strategy: str) -> torch.Tensor:
    """Merge the per-patch embeddings of one image into a single sequence.

    Args:
        image_size: Two-element size of the original image. It is unpacked
            into two scalars and also passed unchanged to `unpad_image`.
        patch_embeddings: Embeddings of the image; index 0 is treated as the
            base patch, any further entries as the additional grid patches.
        strategy: "flat", or a string starting with "spatial". If it also
            contains "unpad", the grid is unpadded and `image_newline` is
            appended.

    Returns:
        The merged embeddings for this image.

    Raises:
        ValueError: If the base patch length is inconsistent with the vision
            config, or if `strategy` is not recognized.
    """
    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    if strategy == "flat":
        # Simply concatenate all patches along the sequence dimension.
        return patch_embeddings.flatten(0, 1)

    if strategy.startswith("spatial"):
        # NOTE(review): the names below assume `image_size` is ordered
        # (width, height). The pair is forwarded in the same order to the
        # HuggingFace helpers, so confirm which ordering those helpers and
        # the image processor actually use -- the "width"/"height" labels
        # here and on the grid shape below may be swapped.
        orig_width, orig_height = image_size
        # Number of vision patches per side of one square input tile.
        height = width = self.config.vision_config.image_size \
            // self.config.vision_config.patch_size

        base_patch_embeds = patch_embeddings[0]
        if height * width != base_patch_embeds.shape[0]:
            raise ValueError(
                "The number of patches is not consistent with the "
                "image size.")

        if patch_embeddings.shape[0] > 1:
            other_patch_embeds = patch_embeddings[1:]

            # image_aspect_ratio == "anyres"
            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                (orig_width, orig_height),
                self.config.image_grid_pinpoints,
                self.config.vision_config.image_size,
            )
            # Arrange the extra patches as a 2D grid of tiles, each tile
            # being a (height, width) grid of embeddings.
            other_patch_embeds = other_patch_embeds \
                .view(num_patch_width, num_patch_height, height, width, -1)

            if "unpad" in strategy:
                # Move the embedding dimension first and fuse the tile grid
                # with the within-tile grid into one 2D feature map.
                other_patch_embeds = other_patch_embeds \
                    .permute(4, 0, 2, 1, 3).contiguous() \
                    .flatten(1, 2).flatten(2, 3)
                other_patch_embeds = unpad_image(other_patch_embeds,
                                                 image_size)
                # Append `image_newline` as one extra column at the end of
                # every row of the feature map.
                other_patch_embeds = torch.cat((
                    other_patch_embeds,
                    self.image_newline[:, None, None] \
                        .expand(*other_patch_embeds.shape[:-1], 1) \
                        .to(other_patch_embeds.device),
                ), dim=-1)
                # Back to (sequence, embedding) layout.
                other_patch_embeds = other_patch_embeds \
                    .flatten(1, 2).transpose(0, 1)
            else:
                other_patch_embeds = other_patch_embeds \
                    .permute(0, 2, 1, 3, 4).contiguous() \
                    .flatten(0, 3)

            # The base patch always comes first in the merged sequence.
            merged_patch_embeddings = torch.cat(
                (base_patch_embeds, other_patch_embeds), dim=0)
        else:
            if "unpad" in strategy:
                # Only the base patch exists: append a single newline row.
                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds,
                     self.image_newline[None] \
                        .to(base_patch_embeds.device)
                    ), dim=0)
            else:
                merged_patch_embeddings = base_patch_embeds

        return merged_patch_embeddings

    raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
        self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
    """Run the vision tower over every patch of every image.

    The batch and patch dimensions of `inputs["data"]` are folded together
    before calling `_image_pixels_to_features`, then restored on the result,
    keeping the last two dimensions of the features as they are.
    """
    assert self.vision_tower is not None

    pixel_values = inputs["data"]
    batch, patches, channels, height, width = pixel_values.shape

    # Fold (batch, patches) into one leading dimension for the vision tower.
    folded_pixels = pixel_values.view(batch * patches, channels, height,
                                      width)
    folded_features = self._image_pixels_to_features(self.vision_tower,
                                                     folded_pixels)

    feature_dims = folded_features.shape[-2:]
    return folded_features.view(batch, patches, *feature_dims)
def _process_image_input(self,
                         image_input: LlavaNextImageInputs) -> torch.Tensor:
    """Convert a validated image input into merged vision embeddings.

    Pixel inputs are first encoded with `_process_image_pixels`; any other
    input's "data" is used as features directly. The features go through
    the multi-modal projector, and the patches of each image are merged
    with the "spatial_unpad" strategy. Results are stacked along dim 0.

    If "image_sizes" is absent, every image is assumed to have the vision
    tower's `image_size` for both entries of its size.
    """
    if image_input["type"] == "pixel_values":
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
    else:
        image_features = image_input["data"]

    patch_embeddings = self.multi_modal_projector(image_features)

    image_sizes = image_input.get("image_sizes")
    if image_sizes is None:
        num_images = image_input["data"].shape[0]
        default_side = self.config.vision_config.image_size
        image_sizes = torch.as_tensor([[default_side, default_side]] *
                                      num_images)

    per_image_embeddings = []
    for index, patch_features in enumerate(patch_embeddings):
        merged = self._merge_image_patch_embeddings(
            image_sizes[index], patch_features, strategy="spatial_unpad")
        per_image_embeddings.append(merged)

    return torch.stack(per_image_embeddings, dim=0)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    **kwargs: object,
) -> SamplerOutput:
    """Run forward pass for LLaVA-NeXT.

    `input_ids` is expected to already account for the positions of the
    to-be-inserted image embeddings: the `<image>` placeholder token
    (`vision_language_config.image_token_id`) is repeated once per image
    feature position, so that `positions` and `attn_metadata` are
    consistent with `input_ids`.

    When image input is present, the text embeddings are looked up from
    `input_ids`, merged with the vision embeddings by
    `merge_vision_embeddings`, and the language model is then called with
    `inputs_embeds` instead of `input_ids`. Without image input the
    language model is called with `input_ids` and `inputs_embeds=None`.

    Args:
        input_ids: Flattened (concatenated) input_ids corresponding to a
            batch.
        positions: Passed through to the language model.
        kv_caches: Passed through to the language model.
        attn_metadata: Passed through to the language model.
        **kwargs: Image inputs. For PIXEL_VALUES, `pixel_values` expects a
            batch with shape [1, num_patches, 3, 336, 336] and must be
            accompanied by an `image_sizes` tensor of shape (batch_size, 2).
            IMAGE_FEATURES inputs are rejected when the model is built.

    Returns:
        The output of the language model.
        NOTE(review): this is named `hidden_states` and is later consumed
        by `compute_logits` as a tensor, although the annotation says
        `SamplerOutput` -- confirm the intended return type.
    """
    inputs_embeds = None

    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is not None:
        vision_embeddings = self._process_image_input(image_input)
        text_embeddings = self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = merge_vision_embeddings(
            input_ids, text_embeddings, vision_embeddings,
            self.vision_language_config.image_token_id)

        # The merged embeddings replace the token ids from here on.
        input_ids = None

    return self.language_model(input_ids,
                               positions,
                               kv_caches,
                               attn_metadata,
                               inputs_embeds=inputs_embeds)
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits from hidden states using the LM head weights."""
    lm_head_weight = self.lm_head.weight
    return self.logits_processor(lm_head_weight, hidden_states,
                                 sampling_metadata)
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Sample the next tokens from `logits` with the model's sampler."""
    return self.sampler(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights into this model's parameters.

    Checkpoint names are first rewritten with `_KEYS_TO_MODIFY_MAPPING`.
    Names containing "vision" are loaded with the default loader. Other
    names are matched against the stacked-parameter table so that separate
    q/k/v and gate/up projections are loaded as shards of the fused
    `qkv_proj` / `gate_up_proj` parameters; anything that does not match
    falls back to the default loader.

    Args:
        weights: Iterable of (checkpoint name, tensor) pairs.

    Raises:
        KeyError: If a (rewritten) name has no matching model parameter.
    """
    # only doing this for language model part for now.
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        # Entries whose name contains "rotary_emb.inv_freq" are skipped.
        if "rotary_emb.inv_freq" in name:
            continue
        # Apply every matching rename in mapping order; a later rename
        # sees the result of the earlier ones.
        for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in name:
                name = name.replace(key_to_modify, new_key)
        use_default_weight_loading = False
        if "vision" in name:
            if self.vision_tower is not None:
                # We only do sharding for language model and
                # not vision model for now.
                use_default_weight_loading = True
        else:
            for (param_name, weight_name,
                 shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Load this tensor as one shard of the fused parameter.
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched (the loop ended without
                # `break`), so this is a regular, unfused parameter.
                use_default_weight_loading = True
        if use_default_weight_loading:
            param = params_dict[name]
            # Prefer a parameter-specific loader when one is attached.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment