Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
465
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
195 additions
and
113 deletions
+195
-113
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+65
-43
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+5
-2
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+7
-2
vllm/model_executor/models/phi3_small.py
vllm/model_executor/models/phi3_small.py
+7
-3
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+111
-63
No files found.
Too many changes to show.
To preserve performance only
465 of 465+
files are displayed.
Plain diff
Email patch
vllm/model_executor/models/paligemma.py
View file @
af7f4372
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
from
torch
import
nn
...
...
@@ -9,20 +10,19 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.gemma
import
GemmaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.
image
import
cached_get_tokenizer
from
vllm.multimodal.
utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.interfaces
import
Supports
Vision
from
.interfaces
import
Supports
MultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
merge_
vision
_embeddings
from
.utils
import
merge_
multimodal
_embeddings
logger
=
init_logger
(
__name__
)
...
...
@@ -31,6 +31,25 @@ _KEYS_TO_MODIFY_MAPPING = {
}
class
PaliGemmaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class
PaliGemmaImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
PaliGemmaImageInputs
=
Union
[
PaliGemmaImagePixelInputs
,
PaliGemmaImageEmbeddingInputs
]
def
get_max_paligemma_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
vision_config
=
hf_config
.
vision_config
...
...
@@ -38,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext):
return
get_max_siglip_image_tokens
(
vision_config
)
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_siglip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
)
mm_data
=
dummy_image_for_siglip
(
vision_config
)
mm_data
=
dummy_image_for_siglip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
...
...
@@ -107,20 +129,11 @@ class PaliGemmaMultiModalProjector(nn.Module):
return
hidden_states
class
PaliGemmaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channels, height, width)"""
PaliGemmaImageInputs
=
PaliGemmaImagePixelInputs
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_paligemma_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_paligemma
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_paligemma
)
class
PaliGemmaForConditionalGeneration
(
nn
.
Module
,
Supports
Vision
):
class
PaliGemmaForConditionalGeneration
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
PaliGemmaConfig
,
...
...
@@ -163,19 +176,31 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
PaliGemmaImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
:
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
PaliGemmaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
PaliGemmaImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
...
...
@@ -187,27 +212,21 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return
image_features
def
_process_image_
pixels
(
def
_process_image_
input
(
self
,
input
s
:
PaliGemmaImage
Pixel
Inputs
,
image_
input
:
PaliGemmaImageInputs
,
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"data"
]
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
return
self
.
_image_pixels_to_features
(
assert
self
.
vision_tower
is
not
None
pixel_values
=
image_input
[
"data"
]
image_features
=
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
,
)
def
_process_image_input
(
self
,
image_input
:
PaliGemmaImageInputs
,
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
,
)
return
self
.
multi_modal_projector
(
image_features
)
def
forward
(
self
,
...
...
@@ -228,7 +247,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_
vision
_embeddings
(
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
...
...
@@ -246,8 +265,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return
hidden_states
# Copied from vllm/model_executor/models/gemma.py
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
language_model
.
embed_tokens
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/persimmon.py
View file @
af7f4372
...
...
@@ -285,8 +285,11 @@ class PersimmonForCausalLM(nn.Module):
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/phi.py
View file @
af7f4372
...
...
@@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
super
().
__init__
()
self
.
config
=
config
# lm_head use bias, cannot share word embeddings
assert
not
config
.
tie_word_embeddings
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
...
...
@@ -286,8 +288,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
return
logits
...
...
vllm/model_executor/models/phi3_small.py
View file @
af7f4372
...
...
@@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module):
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -399,8 +401,11 @@ class Phi3SmallForCausalLM(nn.Module):
def
get_decoder
(
self
):
return
self
.
model
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
if
self
.
dummy_token_indices
is
not
None
and
logits
is
not
None
:
...
...
@@ -446,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
self
.
lm_head
.
weight
.
data
.
copy_
(
self
.
model
.
embed_tokens
.
weight
.
data
)
vllm/model_executor/models/phi3v.py
View file @
af7f4372
...
...
@@ -15,7 +15,8 @@
# limitations under the License.
import
re
from
functools
import
lru_cache
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
torch
...
...
@@ -28,8 +29,7 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -37,13 +37,13 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.
image
import
cached_get_tokenizer
from
vllm.multimodal.
utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
input_processor_for_clip
)
from
.interfaces
import
Supports
Vision
from
.utils
import
merge_
vision
_embeddings
from
.interfaces
import
Supports
MultiModal
from
.utils
import
merge_
multimodal
_embeddings
logger
=
init_logger
(
__name__
)
...
...
@@ -70,6 +70,36 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim
=
768
)
class
Phi3VImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes
:
torch
.
Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
class
Phi3VImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Phi3VImageInputs
=
Union
[
Phi3VImagePixelInputs
,
Phi3VImageEmbeddingInputs
]
class
Phi3ImageEmbeddingBase
(
nn
.
Module
):
def
__init__
(
self
)
->
None
:
...
...
@@ -189,7 +219,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
global_image_features_hd_newline
=
self
.
add_image_newline
(
global_image_features_hd
)
all
_image_
embeddings
=
[]
batch
_image_
features_proj
=
[]
# need a for loop to process each image because of different image sizes
# (patch arrangement is different for each image)
for
i
,
img_size
in
enumerate
(
image_sizes
):
...
...
@@ -207,19 +237,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
sub_image_features_hd
)
# [sub features, separator, global features]
all_image_embeddings
.
append
(
torch
.
cat
([
image_embeddings
=
torch
.
cat
([
sub_image_features_hd_newline
.
squeeze
(
0
),
# (h_crop*12*(w_crop*12+1), 4096)
self
.
glb_GN
.
squeeze
(
0
),
global_image_features_hd_newline
[
i
],
]))
])
img_proj
=
self
.
img_projection
(
image_embeddings
.
to
(
target_device
,
target_dtype
))
batch_image_features_proj
.
append
(
img_proj
)
image_features_proj
=
self
.
img_projection
(
torch
.
stack
(
all_image_embeddings
).
to
(
target_device
,
target_dtype
)
)
# (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)
return
image_features_proj
return
batch_image_features_proj
def
reshape_hd_patches_2x2merge
(
self
,
image_features
,
h_crop
,
w_crop
):
"""
...
...
@@ -259,24 +287,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return
image_features_hd_newline
class
Phi3VImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes
:
torch
.
Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def
_calc_padded_size
(
*
,
width
:
int
,
height
:
int
,
padding_unit
:
int
=
336
):
target_height
=
int
(
np
.
ceil
(
height
/
padding_unit
)
*
padding_unit
)
...
...
@@ -314,12 +324,12 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def
get_phi3v_image_feature_size
(
hf_config
:
PretrainedConfig
,
hf_config
:
Dict
[
str
,
Any
]
,
*
,
input_height
:
int
,
input_width
:
int
,
)
->
int
:
num_crops
=
getattr
(
hf_config
,
"num_crops"
,
16
)
num_crops
=
hf_config
.
get
(
"num_crops"
,
16
)
new_width
,
new_height
=
_calc_hd_transform_size
(
width
=
input_width
,
height
=
input_height
,
hd_num
=
num_crops
)
...
...
@@ -331,24 +341,28 @@ def get_phi3v_image_feature_size(
def
get_max_phi3v_image_tokens
(
ctx
:
InputContext
):
return
get_phi3v_image_feature_size
(
ctx
.
get_hf_
config
(
PretrainedC
onfig
),
ctx
.
get_hf_
image_processor_c
onfig
(
),
input_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
input_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
)
def
dummy_data_for_phi3v
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_phi3v
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_phi3v_image_tokens
(
ctx
)
seq_data
=
dummy_seq_data_for_clip
(
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
seq_len
,
num_images
,
image_token_id
=
_IMAGE_TOKEN_ID
,
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_clip
(
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
num_images
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
)
...
...
@@ -381,7 +395,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_
config
(
PretrainedC
onfig
)
hf_config
=
ctx
.
get_hf_
image_processor_c
onfig
(
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
...
...
@@ -392,7 +406,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
input_width
=
w
,
input_height
=
h
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
raise
NotImplementedError
(
"Embeddings input is not supported yet"
)
image_feature_size
=
image_data
.
shape
[
0
]
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
...
...
@@ -443,7 +457,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_phi3v_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_phi3v
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_phi3v
)
class
Phi3VForCausalLM
(
nn
.
Module
,
Supports
Vision
):
class
Phi3VForCausalLM
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
...
...
@@ -463,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -496,13 +512,18 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Phi3VImage
Pixel
Inputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
Phi3VImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_sizes
=
kwargs
.
pop
(
"image_sizes"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
:
return
None
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
...
...
@@ -516,6 +537,31 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
data
=
self
.
_validate_pixel_values
(
pixel_values
),
image_sizes
=
self
.
_validate_image_sizes
(
image_sizes
))
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
Phi3VImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_process_image_input
(
self
,
image_input
:
Phi3VImageInputs
,
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
assert
self
.
vision_embed_tokens
is
not
None
image_embeds
=
self
.
vision_embed_tokens
(
image_input
[
"data"
],
image_input
[
"image_sizes"
])
return
image_embeds
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
...
@@ -526,11 +572,10 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
vision_embed_tokens
(
image_input
[
"data"
],
image_input
[
"image_sizes"
])
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_
vision
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
image_token_id
)
input_ids
=
None
else
:
...
...
@@ -545,8 +590,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
Prev
1
…
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment