Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7cbd9ec7
Unverified
Commit
7cbd9ec7
authored
Jul 29, 2024
by
Isotr0py
Committed by
GitHub
Jul 29, 2024
Browse files
[Model] Initialize support for InternVL2 series models (#6514)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
3eeb148f
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1042 additions
and
6 deletions
+1042
-6
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+4
-0
examples/offline_inference_vision_language.py
examples/offline_inference_vision_language.py
+15
-0
examples/openai_vision_api_client.py
examples/openai_vision_api_client.py
+2
-0
requirements-test.txt
requirements-test.txt
+1
-0
tests/models/test_internvl.py
tests/models/test_internvl.py
+201
-0
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+1
-1
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+270
-0
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+9
-1
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+471
-0
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+9
-1
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+5
-3
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/internvl.py
vllm/transformers_utils/configs/internvl.py
+51
-0
No files found.
docs/source/models/supported_models.rst
View file @
7cbd9ec7
...
...
@@ -200,6 +200,10 @@ Vision Language Models
- Fuyu
- :code:`adept/fuyu-8b` etc.
-
* - :code:`InternVLChatModel`
- InternVL2
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
-
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
...
...
examples/offline_inference_vision_language.py
View file @
7cbd9ec7
...
...
@@ -106,6 +106,20 @@ def run_minicpmv(question):
return
llm
,
prompt
# InternVL
def
run_internvl
(
question
):
# Generally, InternVL can use chatml template for conversation
TEMPLATE
=
"<|im_start|>User
\n
{prompt}<|im_end|>
\n
<|im_start|>Assistant
\n
"
prompt
=
f
"<image>
\n
{
question
}
\n
"
prompt
=
TEMPLATE
.
format
(
prompt
=
prompt
)
llm
=
LLM
(
model
=
"OpenGVLab/InternVL2-4B"
,
trust_remote_code
=
True
,
max_num_seqs
=
5
,
)
return
llm
,
prompt
# BLIP-2
def
run_blip2
(
question
):
...
...
@@ -125,6 +139,7 @@ model_example_map = {
"chameleon"
:
run_chameleon
,
"minicpmv"
:
run_minicpmv
,
"blip-2"
:
run_blip2
,
"internvl_chat"
:
run_internvl
,
}
...
...
examples/openai_vision_api_client.py
View file @
7cbd9ec7
...
...
@@ -42,6 +42,7 @@ chat_completion_from_url = client.chat.completions.create(
],
}],
model
=
model
,
max_tokens
=
64
,
)
result
=
chat_completion_from_url
.
choices
[
0
].
message
.
content
...
...
@@ -78,6 +79,7 @@ chat_completion_from_base64 = client.chat.completions.create(
],
}],
model
=
model
,
max_tokens
=
64
,
)
result
=
chat_completion_from_base64
.
choices
[
0
].
message
.
content
...
...
requirements-test.txt
View file @
7cbd9ec7
...
...
@@ -16,6 +16,7 @@ ray
sentence-transformers # required for embedding
sparseml==1.8.0 # required for compressed-tensors
compressed-tensors==0.4.0 # required for compressed-tensors
timm # required for internvl test
# Benchmarking
aiohttp
...
...
tests/models/test_internvl.py
0 → 100644
View file @
7cbd9ec7
import
types
from
typing
import
List
,
Optional
,
Type
import
pytest
import
torch
from
huggingface_hub
import
snapshot_download
from
PIL.Image
import
Image
from
vllm.model_executor.models.internvl
import
(
IMG_CONTEXT
,
IMG_END
,
IMG_START
,
image_to_pixel_values
)
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.utils
import
is_cpu
from
..conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
,
_ImageAssets
from
.utils
import
check_logprobs_close
pytestmark
=
pytest
.
mark
.
vlm
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"<|im_start|>User
\n
<image>
\n
What's the content in the center of the image?<|im_end|>
\n
<|im_start|>Assistant
\n
"
,
# noqa: E501
"cherry_blossom"
:
"<|im_start|>User
\n
<image>
\n
What is the season?<|im_end|>
\n
<|im_start|>Assistant
\n
"
,
# noqa: E501
})
# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
models
=
[
snapshot_download
(
"OpenGVLab/InternVL2-1B"
),
snapshot_download
(
"OpenGVLab/InternVL2-2B"
),
# snapshot_download("OpenGVLab/InternVL2-4B"), # broken
]
class
InternVLProcessor
:
"""A simple processor for InternVL2 HF model which misses a processor."""
def
__init__
(
self
,
hf_runner
:
HfRunner
):
self
.
num_image_token
=
hf_runner
.
model
.
num_image_token
self
.
tokenizer
=
hf_runner
.
tokenizer
self
.
dtype
=
hf_runner
.
model
.
dtype
def
__call__
(
self
,
text
:
str
,
images
:
Image
,
**
kwargs
):
pixel_values
=
image_to_pixel_values
(
images
).
to
(
self
.
dtype
)
num_patches_list
=
[
pixel_values
.
shape
[
0
]]
for
num_patches
in
num_patches_list
:
context_tokens
=
IMG_CONTEXT
*
self
.
num_image_token
*
num_patches
image_tokens
=
IMG_START
+
context_tokens
+
IMG_END
text
=
text
.
replace
(
'<image>'
,
image_tokens
,
1
)
prompt
=
self
.
tokenizer
(
text
,
return_tensors
=
"pt"
)
prompt
.
update
({
"pixel_values"
:
pixel_values
})
return
prompt
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def
generate
(
self
,
pixel_values
:
torch
.
FloatTensor
,
input_ids
:
torch
.
FloatTensor
,
attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
**
generate_kwargs
,
)
->
torch
.
LongTensor
:
"""Generate method for InternVL2 model without fixed use_cache."""
assert
self
.
img_context_token_id
is
not
None
vit_embeds
=
self
.
extract_feature
(
pixel_values
)
input_embeds
=
self
.
language_model
.
get_input_embeddings
()(
input_ids
)
B
,
N
,
C
=
input_embeds
.
shape
input_embeds
=
input_embeds
.
reshape
(
B
*
N
,
C
)
input_ids
=
input_ids
.
reshape
(
B
*
N
)
selected
=
(
input_ids
==
self
.
img_context_token_id
)
assert
selected
.
sum
()
!=
0
input_embeds
[
selected
]
=
vit_embeds
.
reshape
(
-
1
,
C
).
to
(
input_embeds
.
device
)
input_embeds
=
input_embeds
.
reshape
(
B
,
N
,
C
)
outputs
=
self
.
language_model
.
generate
(
inputs_embeds
=
input_embeds
,
attention_mask
=
attention_mask
,
**
generate_kwargs
,
)
return
outputs
def
run_test
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
image_assets
:
_ImageAssets
,
model
:
str
,
*
,
size_factors
:
List
[
float
],
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_image
=
[(
[
prompt
for
_
in
size_factors
],
[
rescale_image_size
(
image
,
factor
)
for
factor
in
size_factors
],
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
max_model_len
=
4096
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs_per_image
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
)
for
prompts
,
images
in
inputs_per_image
]
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
img_context_token_id
=
hf_model
.
tokenizer
.
convert_tokens_to_ids
(
"<IMG_CONTEXT>"
)
hf_model
.
model
.
img_context_token_id
=
img_context_token_id
hf_model
.
processor
=
InternVLProcessor
(
hf_model
)
hf_model
.
model
.
get_output_embeddings
=
lambda
:
\
hf_model
.
model
.
language_model
.
get_output_embeddings
()
hf_model
.
model
.
generate
=
types
.
MethodType
(
generate
,
hf_model
.
model
)
eos_token_id
=
hf_model
.
tokenizer
.
eos_token_id
hf_outputs_per_image
=
[
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
hf_images
,
eos_token_id
=
eos_token_id
)
for
prompts
,
hf_images
in
inputs_per_image
]
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_image
,
vllm_outputs_per_image
):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
target_dtype
=
"half"
if
is_cpu
():
target_dtype
=
"bfloat16"
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
[
# No image
[],
# Single-scale
[
1.0
],
# Single-scale, batched
[
1.0
,
1.0
,
1.0
],
# Multi-scale
[
0.25
,
0.5
,
1.0
],
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
torch
.
inference_mode
()
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
run_test
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
=
size_factors
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
vllm/entrypoints/chat_utils.py
View file @
7cbd9ec7
...
...
@@ -107,7 +107,7 @@ def _image_token_str(model_config: ModelConfig,
return
None
if
model_type
.
startswith
(
"llava"
):
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
if
model_type
==
"chameleon"
:
if
model_type
in
(
"chameleon"
,
"internvl_chat"
)
:
return
"<image>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
...
...
vllm/model_executor/models/__init__.py
View file @
7cbd9ec7
...
...
@@ -37,6 +37,7 @@ _GENERATION_MODELS = {
"GPTNeoXForCausalLM"
:
(
"gpt_neox"
,
"GPTNeoXForCausalLM"
),
"InternLMForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"InternLM2ForCausalLM"
:
(
"internlm2"
,
"InternLM2ForCausalLM"
),
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"JAISLMHeadModel"
:
(
"jais"
,
"JAISLMHeadModel"
),
"LlamaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"LlavaForConditionalGeneration"
:
...
...
vllm/model_executor/models/intern_vit.py
0 → 100644
View file @
7cbd9ec7
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
NORM2FN
=
{
'rms_norm'
:
RMSNorm
,
'layer_norm'
:
nn
.
LayerNorm
,
}
class
InternVisionEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
class_embedding
=
nn
.
Parameter
(
torch
.
randn
(
1
,
1
,
self
.
embed_dim
))
self
.
patch_embedding
=
nn
.
Conv2d
(
in_channels
=
3
,
out_channels
=
self
.
embed_dim
,
kernel_size
=
self
.
patch_size
,
stride
=
self
.
patch_size
)
self
.
num_patches
=
(
self
.
image_size
//
self
.
patch_size
)
**
2
self
.
num_positions
=
self
.
num_patches
+
1
self
.
position_embedding
=
nn
.
Parameter
(
torch
.
randn
(
1
,
self
.
num_positions
,
self
.
embed_dim
))
def
_get_pos_embed
(
self
,
pos_embed
,
H
,
W
):
target_dtype
=
pos_embed
.
dtype
pos_embed
=
pos_embed
.
float
().
reshape
(
1
,
self
.
image_size
//
self
.
patch_size
,
self
.
image_size
//
self
.
patch_size
,
-
1
).
permute
(
0
,
3
,
1
,
2
)
pos_embed
=
F
.
interpolate
(
pos_embed
,
size
=
(
H
,
W
),
mode
=
'bicubic'
,
align_corners
=
False
)
pos_embed
=
pos_embed
.
reshape
(
1
,
-
1
,
H
*
W
).
permute
(
0
,
2
,
1
).
to
(
target_dtype
)
return
pos_embed
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
)
->
torch
.
Tensor
:
target_dtype
=
self
.
patch_embedding
.
weight
.
dtype
patch_embeds
=
self
.
patch_embedding
(
pixel_values
.
to
(
target_dtype
))
# shape = [*, channel, width, height]
batch_size
,
_
,
height
,
width
=
patch_embeds
.
shape
patch_embeds
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
class_embeds
=
self
.
class_embedding
.
expand
(
batch_size
,
1
,
-
1
).
to
(
target_dtype
)
embeddings
=
torch
.
cat
([
class_embeds
,
patch_embeds
],
dim
=
1
)
position_embedding
=
torch
.
cat
([
self
.
position_embedding
[:,
:
1
,
:],
self
.
_get_pos_embed
(
self
.
position_embedding
[:,
1
:,
:],
height
,
width
)
],
dim
=
1
)
embeddings
=
embeddings
+
position_embedding
.
to
(
target_dtype
)
return
embeddings
class
InternAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
:
PretrainedConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
embed_dim
//
self
.
num_heads
if
self
.
head_dim
*
self
.
num_heads
!=
self
.
embed_dim
:
raise
ValueError
(
f
'embed_dim must be divisible by num_heads '
f
'(got `embed_dim`:
{
self
.
embed_dim
}
and `num_heads`:'
f
'
{
self
.
num_heads
}
).'
)
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
qkv
=
nn
.
Linear
(
self
.
embed_dim
,
3
*
self
.
embed_dim
,
bias
=
config
.
qkv_bias
)
self
.
qk_normalization
=
config
.
qk_normalization
if
self
.
qk_normalization
:
self
.
q_norm
=
RMSNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
def
forward
(
self
,
x
):
B
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
B
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
.
unbind
(
0
)
if
self
.
qk_normalization
:
B_
,
H_
,
N_
,
D_
=
q
.
shape
q
=
self
.
q_norm
(
q
.
transpose
(
1
,
2
).
flatten
(
-
2
,
-
1
)).
view
(
B_
,
N_
,
H_
,
D_
).
transpose
(
1
,
2
)
k
=
self
.
k_norm
(
k
.
transpose
(
1
,
2
).
flatten
(
-
2
,
-
1
)).
view
(
B_
,
N_
,
H_
,
D_
).
transpose
(
1
,
2
)
x
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
scale
=
self
.
scale
)
x
=
x
.
transpose
(
1
,
2
).
reshape
(
B
,
N
,
C
)
x
=
self
.
proj
(
x
)
return
x
class
InternMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
get_act_fn
(
config
.
hidden_act
)
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
bias
=
True
,
quant_config
=
quant_config
)
self
.
fc2
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
,
_
=
self
.
fc2
(
hidden_states
)
return
hidden_states
class
InternVisionEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
self
.
norm_type
=
config
.
norm_type
self
.
attn
=
InternAttention
(
config
)
self
.
mlp
=
InternMLP
(
config
,
quant_config
=
quant_config
)
self
.
norm1
=
NORM2FN
[
self
.
norm_type
](
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
norm2
=
NORM2FN
[
self
.
norm_type
](
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
ls1
=
nn
.
Parameter
(
config
.
initializer_factor
*
torch
.
ones
(
self
.
embed_dim
))
self
.
ls2
=
nn
.
Parameter
(
config
.
initializer_factor
*
torch
.
ones
(
self
.
embed_dim
))
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
):
hidden_states
=
hidden_states
+
self
.
attn
(
self
.
norm1
(
hidden_states
))
*
self
.
ls1
hidden_states
=
hidden_states
+
self
.
mlp
(
self
.
norm2
(
hidden_states
))
*
self
.
ls2
return
hidden_states
class
InternVisionEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
):
super
().
__init__
()
self
.
config
=
config
if
num_hidden_layers_override
is
None
:
num_hidden_layers
=
config
.
num_hidden_layers
else
:
num_hidden_layers
=
num_hidden_layers_override
self
.
layers
=
nn
.
ModuleList
([
InternVisionEncoderLayer
(
config
=
config
,
quant_config
=
quant_config
)
for
_
in
range
(
num_hidden_layers
)
])
def
forward
(
self
,
inputs_embeds
:
torch
.
Tensor
):
hidden_states
=
inputs_embeds
for
encoder_layer
in
self
.
layers
:
hidden_states
=
encoder_layer
(
hidden_states
)
return
hidden_states
class
InternVisionModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
embeddings
=
InternVisionEmbeddings
(
config
)
self
.
encoder
=
InternVisionEncoder
(
config
=
config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
)
def
resize_pos_embeddings
(
self
,
old_size
,
new_size
,
patch_size
):
pos_emb
=
self
.
embeddings
.
position_embedding
_
,
num_positions
,
embed_dim
=
pos_emb
.
shape
cls_emb
=
pos_emb
[:,
:
1
,
:]
pos_emb
=
pos_emb
[:,
1
:,
:].
reshape
(
1
,
old_size
//
patch_size
,
old_size
//
patch_size
,
-
1
).
permute
(
0
,
3
,
1
,
2
)
pos_emb
=
F
.
interpolate
(
pos_emb
.
float
(),
size
=
new_size
//
patch_size
,
mode
=
'bicubic'
,
align_corners
=
False
)
pos_emb
=
pos_emb
.
to
(
cls_emb
.
dtype
).
reshape
(
1
,
embed_dim
,
-
1
).
permute
(
0
,
2
,
1
)
pos_emb
=
torch
.
cat
([
cls_emb
,
pos_emb
],
dim
=
1
)
self
.
embeddings
.
position_embedding
=
nn
.
Parameter
(
pos_emb
)
self
.
embeddings
.
image_size
=
new_size
def
get_input_embeddings
(
self
):
return
self
.
embeddings
def
forward
(
self
,
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
,
pixel_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
FloatTensor
:
if
pixel_values
is
None
and
pixel_embeds
is
None
:
raise
ValueError
(
'You have to specify pixel_values or pixel_embeds'
)
if
pixel_embeds
is
not
None
:
hidden_states
=
pixel_embeds
elif
pixel_values
is
not
None
:
if
pixel_values
.
ndim
==
4
:
hidden_states
=
self
.
embeddings
(
pixel_values
)
else
:
raise
ValueError
(
f
'wrong pixel_values size:
{
pixel_values
.
shape
}
'
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
return
encoder_outputs
vllm/model_executor/models/internlm2.py
View file @
7cbd9ec7
...
...
@@ -219,13 +219,21 @@ class InternLM2Model(nn.Module):
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
tok_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
IntermediateTensors
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
tok_embeddings
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
...
...
vllm/model_executor/models/internvl.py
0 → 100644
View file @
7cbd9ec7
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch.nn
as
nn
import
torchvision.transforms
as
T
from
PIL
import
Image
from
transformers
import
PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.intern_vit
import
InternVisionModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
BatchedTensors
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_num_patches
)
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
IMG_START
=
'<img>'
IMG_END
=
'</img>'
IMG_CONTEXT
=
'<IMG_CONTEXT>'
IMAGENET_MEAN
=
(
0.485
,
0.456
,
0.406
)
IMAGENET_STD
=
(
0.229
,
0.224
,
0.225
)
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
3000
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
500
class
InternVLImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
BatchedTensors
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def
build_transform
(
input_size
):
MEAN
,
STD
=
IMAGENET_MEAN
,
IMAGENET_STD
transform
=
T
.
Compose
([
T
.
Lambda
(
lambda
img
:
img
.
convert
(
'RGB'
)
if
img
.
mode
!=
'RGB'
else
img
),
T
.
Resize
((
input_size
,
input_size
),
interpolation
=
T
.
InterpolationMode
.
BICUBIC
),
T
.
ToTensor
(),
T
.
Normalize
(
mean
=
MEAN
,
std
=
STD
)
])
return
transform
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def
find_closest_aspect_ratio
(
aspect_ratio
,
target_ratios
,
width
,
height
,
image_size
):
best_ratio_diff
=
float
(
'inf'
)
best_ratio
=
(
1
,
1
)
area
=
width
*
height
for
ratio
in
target_ratios
:
target_aspect_ratio
=
ratio
[
0
]
/
ratio
[
1
]
ratio_diff
=
abs
(
aspect_ratio
-
target_aspect_ratio
)
if
ratio_diff
<
best_ratio_diff
:
best_ratio_diff
=
ratio_diff
best_ratio
=
ratio
elif
ratio_diff
==
best_ratio_diff
:
if
area
>
0.5
*
image_size
*
image_size
*
ratio
[
0
]
*
ratio
[
1
]:
best_ratio
=
ratio
return
best_ratio
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
=
1
,
max_num
=
6
,
image_size
=
448
):
aspect_ratio
=
orig_width
/
orig_height
# calculate the existing image aspect ratio
target_ratios
=
set
((
i
,
j
)
for
n
in
range
(
min_num
,
max_num
+
1
)
for
i
in
range
(
1
,
n
+
1
)
for
j
in
range
(
1
,
n
+
1
)
if
i
*
j
<=
max_num
and
i
*
j
>=
min_num
)
target_ratios
=
sorted
(
target_ratios
,
key
=
lambda
x
:
x
[
0
]
*
x
[
1
])
# find the closest aspect ratio to the target
target_aspect_ratio
=
find_closest_aspect_ratio
(
aspect_ratio
,
target_ratios
,
orig_width
,
orig_height
,
image_size
)
# calculate the target width and height
target_width
=
image_size
*
target_aspect_ratio
[
0
]
target_height
=
image_size
*
target_aspect_ratio
[
1
]
blocks
=
target_aspect_ratio
[
0
]
*
target_aspect_ratio
[
1
]
return
blocks
,
target_width
,
target_height
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
dynamic_preprocess
(
image
,
min_num
=
1
,
max_num
=
6
,
image_size
=
448
,
use_thumbnail
=
False
):
orig_width
,
orig_height
=
image
.
size
blocks
,
target_width
,
target_height
=
calculate_num_blocks
(
orig_width
,
orig_height
,
min_num
,
max_num
,
image_size
)
# resize the image
resized_img
=
image
.
resize
((
target_width
,
target_height
))
processed_images
=
[]
for
i
in
range
(
blocks
):
box
=
((
i
%
(
target_width
//
image_size
))
*
image_size
,
(
i
//
(
target_width
//
image_size
))
*
image_size
,
((
i
%
(
target_width
//
image_size
))
+
1
)
*
image_size
,
((
i
//
(
target_width
//
image_size
))
+
1
)
*
image_size
)
# split the image
split_img
=
resized_img
.
crop
(
box
)
processed_images
.
append
(
split_img
)
assert
len
(
processed_images
)
==
blocks
if
use_thumbnail
and
len
(
processed_images
)
!=
1
:
thumbnail_img
=
image
.
resize
((
image_size
,
image_size
))
processed_images
.
append
(
thumbnail_img
)
return
processed_images
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
image_to_pixel_values
(
image
:
Image
.
Image
,
input_size
=
448
,
max_num
=
6
):
transform
=
build_transform
(
input_size
=
input_size
)
images
=
dynamic_preprocess
(
image
,
image_size
=
input_size
,
use_thumbnail
=
True
,
max_num
=
max_num
)
pixel_values
=
[
transform
(
image
)
for
image
in
images
]
pixel_values
=
torch
.
stack
(
pixel_values
)
return
pixel_values
def
get_internvl_num_patches
(
image_size
:
int
,
patch_size
:
int
,
downsample_ratio
:
float
):
return
int
(
get_clip_num_patches
(
image_size
=
image_size
,
patch_size
=
patch_size
)
*
(
downsample_ratio
**
2
))
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
vision_config
=
hf_config
.
vision_config
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
return
num_patches
*
7
def
input_processor_for_internvl
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_inputs
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
vision_config
=
hf_config
.
vision_config
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
width
,
height
=
image_data
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
raise
NotImplementedError
(
"Embeddings input is not supported yet"
)
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
prompt
=
llm_inputs
[
"prompt"
]
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
(
num_blocks
+
1
)
*
num_patches
+
IMG_END
new_prompt
=
prompt
.
replace
(
'<image>'
,
image_prompt
,
1
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
return
LLMInputs
(
prompt
=
prompt
,
prompt_token_ids
=
new_prompt_token_ids
,
multi_modal_data
=
multi_modal_data
)
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
):
if
isinstance
(
data
,
Image
.
Image
):
data
=
image_to_pixel_values
(
data
)
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
add_special_tokens
=
False
,
return_tensors
=
"pt"
)[
0
]
return
MultiModalInputs
({
"pixel_values"
:
data
,
"image_token_id"
:
image_token_id
})
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
):
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
vision_config
=
hf_config
.
vision_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
add_special_tokens
=
False
)[
0
],
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_clip
(
vision_config
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
)
return
seq_data
,
mm_data
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_internvl
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_internvl_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_internvl
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_internvl
)
class
InternVLChatModel
(
nn
.
Module
,
SupportsVision
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
image_size
=
config
.
force_image_size
or
config
.
vision_config
.
image_size
patch_size
=
config
.
vision_config
.
patch_size
self
.
patch_size
=
patch_size
self
.
select_layer
=
config
.
select_layer
self
.
num_image_token
=
int
(
(
image_size
//
patch_size
)
**
2
*
(
config
.
downsample_ratio
**
2
))
self
.
downsample_ratio
=
config
.
downsample_ratio
self
.
ps_version
=
config
.
ps_version
vision_feature_layer
=
self
.
select_layer
if
vision_feature_layer
<
0
:
num_hidden_layers
=
config
.
vision_config
.
num_hidden_layers
\
+
vision_feature_layer
+
1
else
:
num_hidden_layers
=
vision_feature_layer
+
1
self
.
vision_model
=
InternVisionModel
(
config
.
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
)
llm_class
=
ModelRegistry
.
load_model_cls
(
config
.
text_config
.
architectures
[
0
])
self
.
language_model
=
llm_class
(
config
.
text_config
,
cache_config
,
quant_config
)
vit_hidden_size
=
config
.
vision_config
.
hidden_size
llm_hidden_size
=
config
.
text_config
.
hidden_size
self
.
mlp1
=
nn
.
Sequential
(
nn
.
LayerNorm
(
vit_hidden_size
*
int
(
1
/
self
.
downsample_ratio
)
**
2
),
nn
.
Linear
(
vit_hidden_size
*
int
(
1
/
self
.
downsample_ratio
)
**
2
,
llm_hidden_size
),
nn
.
GELU
(),
nn
.
Linear
(
llm_hidden_size
,
llm_hidden_size
))
self
.
img_context_token_id
=
None
def
pixel_shuffle
(
self
,
x
,
scale_factor
=
0.5
):
n
,
w
,
h
,
c
=
x
.
size
()
# N, W, H, C --> N, W, H * scale, C // scale
x
=
x
.
view
(
n
,
w
,
int
(
h
*
scale_factor
),
int
(
c
/
scale_factor
))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x
=
x
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
x
=
x
.
view
(
n
,
int
(
h
*
scale_factor
),
int
(
w
*
scale_factor
),
int
(
c
/
(
scale_factor
*
scale_factor
)))
if
self
.
ps_version
==
'v1'
:
pass
else
:
x
=
x
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
return
x
def
extract_feature
(
self
,
pixel_values
):
vit_embeds
=
self
.
vision_model
(
pixel_values
=
pixel_values
)
vit_embeds
=
vit_embeds
[:,
1
:,
:]
h
=
w
=
int
(
vit_embeds
.
shape
[
1
]
**
0.5
)
vit_embeds
=
vit_embeds
.
reshape
(
vit_embeds
.
shape
[
0
],
h
,
w
,
-
1
)
vit_embeds
=
self
.
pixel_shuffle
(
vit_embeds
,
scale_factor
=
self
.
downsample_ratio
)
vit_embeds
=
vit_embeds
.
reshape
(
vit_embeds
.
shape
[
0
],
-
1
,
vit_embeds
.
shape
[
-
1
])
vit_embeds
=
self
.
mlp1
(
vit_embeds
)
return
vit_embeds
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
list
(
data
.
shape
[
1
:])
!=
[
2
]:
raise
ValueError
(
f
"The expected image sizes shape is batch dimension plus "
f
"
{
[
2
]
}
. You supplied
{
data
.
shape
}
."
)
return
data
def
_validate_pixel_values
(
self
,
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"num_patches"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
"The expected shape of pixel values in each batch element "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
InternVLImagePixelInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_token_id
=
kwargs
.
pop
(
"image_token_id"
,
None
)
if
pixel_values
is
None
:
return
None
self
.
img_context_token_id
=
image_token_id
[
0
]
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
vit_embeds
=
self
.
extract_feature
(
image_input
[
"data"
])
inputs_embeds
=
merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vit_embeds
,
self
.
img_context_token_id
)
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
(
".gate_up_proj"
,
".w1"
,
0
),
(
".gate_up_proj"
,
".w3"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
self
.
config
.
text_config
.
tie_word_embeddings
\
and
"lm_head.weight"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# We only do sharding for language model
# and not vision model for now.
if
"vision_embed_tokens"
in
name
and
self
.
vision_embed_tokens
:
continue
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
if
"wqkv"
in
name
:
config
=
self
.
config
.
text_config
kv_groups
=
(
config
.
num_attention_heads
//
config
.
num_key_value_heads
)
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
loaded_weight
=
loaded_weight
.
view
(
-
1
,
2
+
kv_groups
,
head_dim
,
loaded_weight
.
shape
[
-
1
])
wq
,
wk
,
wv
=
torch
.
split
(
loaded_weight
,
[
kv_groups
,
1
,
1
],
dim
=
1
)
wq
=
wq
.
reshape
(
-
1
,
wq
.
shape
[
-
1
])
wk
=
wk
.
reshape
(
-
1
,
wk
.
shape
[
-
1
])
wv
=
wv
.
reshape
(
-
1
,
wv
.
shape
[
-
1
])
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
wq
,
'q'
)
weight_loader
(
param
,
wk
,
'k'
)
weight_loader
(
param
,
wv
,
'v'
)
continue
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/qwen2.py
View file @
7cbd9ec7
...
...
@@ -243,13 +243,21 @@ class Qwen2Model(nn.Module):
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
...
...
vllm/transformers_utils/config.py
View file @
7cbd9ec7
...
...
@@ -6,9 +6,10 @@ from transformers import GenerationConfig, PretrainedConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
JAISConfig
,
MedusaConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
NemotronConfig
,
RWConfig
)
InternVLChatConfig
,
JAISConfig
,
MedusaConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
NemotronConfig
,
RWConfig
)
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
...
...
@@ -26,6 +27,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"jais"
:
JAISConfig
,
"mlp_speculator"
:
MLPSpeculatorConfig
,
"medusa"
:
MedusaConfig
,
"internvl_chat"
:
InternVLChatConfig
,
"nemotron"
:
NemotronConfig
,
}
...
...
vllm/transformers_utils/configs/__init__.py
View file @
7cbd9ec7
...
...
@@ -4,6 +4,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from
vllm.transformers_utils.configs.falcon
import
RWConfig
from
vllm.transformers_utils.configs.internvl
import
InternVLChatConfig
from
vllm.transformers_utils.configs.jais
import
JAISConfig
from
vllm.transformers_utils.configs.medusa
import
MedusaConfig
from
vllm.transformers_utils.configs.mlp_speculator
import
MLPSpeculatorConfig
...
...
@@ -15,6 +16,7 @@ __all__ = [
"DbrxConfig"
,
"MPTConfig"
,
"RWConfig"
,
"InternVLChatConfig"
,
"JAISConfig"
,
"MedusaConfig"
,
"MLPSpeculatorConfig"
,
...
...
vllm/transformers_utils/configs/internvl.py
0 → 100644
View file @
7cbd9ec7
# Adapted from
# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
transformers.configuration_utils
import
PretrainedConfig
class
InternVLChatConfig
(
PretrainedConfig
):
model_type
=
'internvl_chat'
is_composition
=
True
def
__init__
(
self
,
vision_config
=
None
,
llm_config
=
None
,
use_backbone_lora
=
0
,
use_llm_lora
=
0
,
select_layer
=-
1
,
force_image_size
=
None
,
downsample_ratio
=
0.5
,
template
=
None
,
dynamic_image_size
=
False
,
use_thumbnail
=
False
,
ps_version
=
'v1'
,
min_dynamic_patch
=
1
,
max_dynamic_patch
=
6
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
if
vision_config
is
None
:
vision_config
=
{}
if
llm_config
is
None
:
llm_config
=
{}
self
.
vision_config
=
PretrainedConfig
(
**
vision_config
)
self
.
text_config
=
PretrainedConfig
(
**
llm_config
)
self
.
use_backbone_lora
=
use_backbone_lora
self
.
use_llm_lora
=
use_llm_lora
self
.
select_layer
=
select_layer
self
.
force_image_size
=
force_image_size
self
.
downsample_ratio
=
downsample_ratio
self
.
template
=
template
self
.
dynamic_image_size
=
dynamic_image_size
self
.
use_thumbnail
=
use_thumbnail
self
.
ps_version
=
ps_version
# pixel shuffle version
self
.
min_dynamic_patch
=
min_dynamic_patch
self
.
max_dynamic_patch
=
max_dynamic_patch
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment