Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d394787e
Unverified
Commit
d394787e
authored
Sep 11, 2024
by
Patrick von Platen
Committed by
GitHub
Sep 11, 2024
Browse files
Pixtral (#8377)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
775f00f8
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
807 additions
and
9 deletions
+807
-9
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+5
-0
examples/offline_inference_pixtral.py
examples/offline_inference_pixtral.py
+164
-0
requirements-common.txt
requirements-common.txt
+1
-1
tests/models/test_pixtral.py
tests/models/test_pixtral.py
+64
-0
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+2
-1
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+2
-0
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+551
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+18
-7
No files found.
docs/source/models/supported_models.rst
View file @
d394787e
...
...
@@ -247,6 +247,11 @@ Multimodal Language Models
- Image\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
-
* - :code:`PixtralForConditionalGeneration`
- Pixtral
- Image\ :sup:`+`
- :code:`mistralai/Pixtral-12B-2409`
-
* - :code:`QWenLMHeadModel`
- Qwen-VL
- Image\ :sup:`E`
...
...
examples/offline_inference_pixtral.py
0 → 100644
View file @
d394787e
# ruff: noqa
import
argparse
from
vllm
import
LLM
from
vllm.sampling_params
import
SamplingParams
# This script is an offline demo for running Pixtral.
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Pixtral-12B-2409",
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
#
# Usage:
# python offline_inference_pixtral.py simple
# python offline_inference_pixtral.py advanced
def run_simple_demo():
    """Ask Pixtral to describe a single remote image and print the reply."""
    params = SamplingParams(max_tokens=8192)
    engine = LLM(model="mistralai/Pixtral-12B-2409", tokenizer_mode="mistral")

    # One user turn made of a text part followed by an image-URL part.
    text_part = {
        "type": "text",
        "text": "Describe this image in one sentence.",
    }
    image_part = {
        "type": "image_url",
        "image_url": {"url": "https://picsum.photos/id/237/200/300"},
    }
    conversation = [{"role": "user", "content": [text_part, image_part]}]

    replies = engine.chat(conversation, sampling_params=params)
    print(replies[0].outputs[0].text)
def run_advanced_demo():
    """Run a multi-turn, multi-image chat with Pixtral and print the reply.

    The engine is configured to accept up to five images per prompt, and the
    batched-token budget is set to that image limit times 4096.
    """
    image_limit = 5
    token_budget_per_image = 4096

    params = SamplingParams(max_tokens=8192, temperature=0.7)
    engine = LLM(
        model="mistralai/Pixtral-12B-2409",
        tokenizer_mode="mistral",
        limit_mm_per_prompt={"image": image_limit},
        max_num_batched_tokens=image_limit * token_budget_per_image,
    )

    def image_part(url):
        # Build one image-URL content part.
        return {"type": "image_url", "image_url": {"url": url}}

    first_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the following image."},
            image_part(
                "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
            ),
            image_part("https://picsum.photos/seed/picsum/200/300"),
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": "The images show nature.",
    }
    follow_up_turn = {
        "role": "user",
        "content": "More details please and answer only in French!.",
    }
    extra_image_turn = {
        "role": "user",
        "content": [image_part("https://picsum.photos/id/32/512/512")],
    }
    conversation = [
        first_turn,
        assistant_turn,
        follow_up_turn,
        extra_image_turn,
    ]

    replies = engine.chat(messages=conversation, sampling_params=params)
    print(replies[0].outputs[0].text)
def main():
    """Parse the positional ``mode`` argument and run the matching demo."""
    cli = argparse.ArgumentParser(
        description="Run a demo in simple or advanced mode.")
    cli.add_argument(
        "mode",
        choices=["simple", "advanced"],
        help="Specify the demo mode: 'simple' or 'advanced'",
    )
    selected = cli.parse_args().mode

    # Mode name -> (banner printed before running, demo entry point).
    demos = {
        "simple": ("Running simple demo...", run_simple_demo),
        "advanced": ("Running advanced demo...", run_advanced_demo),
    }
    if selected in demos:
        banner, demo = demos[selected]
        print(banner)
        demo()
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
requirements-common.txt
View file @
d394787e
...
...
@@ -25,7 +25,7 @@ pyzmq
msgspec
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.
3.4
mistral_common >= 1.
4.0
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
einops # Required for Qwen2-VL.
tests/models/test_pixtral.py
0 → 100644
View file @
d394787e
"""Check vLLM's greedy-sampling output for Pixtral against reference captions.

Run `pytest tests/models/test_pixtral.py`.
"""
import
pytest
from
vllm.sampling_params
import
SamplingParams
# Apply the `vlm` marker to every test in this module.
pytestmark = pytest.mark.vlm

# Model identifiers that `test_models` is parametrized over.
MODELS = ["mistralai/Pixtral-12B-2409"]
@pytest.mark.skip(
    reason=
    "Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check greedy Pixtral captions for two images against fixed strings.

    Each image URL is sent in its own single-turn chat together with the same
    text prompt; the generated text must equal the reference caption exactly.

    ``example_prompts`` and ``num_logprobs`` are accepted because they are
    part of the fixture/parametrize signature, but are not used here.
    """
    image_urls = [
        "https://picsum.photos/id/237/200/300",
        "https://picsum.photos/seed/picsum/200/300",
    ]
    expected = [
        "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.",  # noqa
        "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset.",  # noqa
    ]
    prompt = "Describe the image in one short sentence."

    # Honour the parametrized generation budget instead of a hard-coded value;
    # temperature 0.0 keeps decoding greedy so the captions are reproducible.
    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)

    with vllm_runner(model, dtype=dtype,
                     tokenizer_mode="mistral") as vllm_model:
        for image_url, reference in zip(image_urls, expected):
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                },
            ]
            outputs = vllm_model.model.chat(messages,
                                            sampling_params=sampling_params)
            assert outputs[0].outputs[0].text == reference
vllm/entrypoints/chat_utils.py
View file @
d394787e
...
...
@@ -148,7 +148,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
f
"<|image_
{
current_count
}
|>"
if
model_type
==
"minicpmv"
:
return
"(<image>./</image>)"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
):
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"paligemma"
,
"pixtral"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
==
"qwen"
:
...
...
vllm/model_executor/models/__init__.py
View file @
d394787e
...
...
@@ -92,6 +92,8 @@ _MULTIMODAL_MODELS = {
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"PixtralForConditionalGeneration"
:
(
"pixtral"
,
"PixtralForConditionalGeneration"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
}
...
...
vllm/model_executor/models/pixtral.py
0 → 100644
View file @
d394787e
import
math
from
array
import
array
from
dataclasses
import
dataclass
,
fields
from
itertools
import
tee
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mistral_common.protocol.instruct.messages
import
ImageChunk
from
PIL
import
Image
from
transformers
import
PretrainedConfig
from
xformers.ops.fmha
import
memory_efficient_attention
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
.interfaces
import
SupportsMultiModal
from
.utils
import
init_vllm_registered_model
def get_max_pixtral_image_tokens(ctx: InputContext):
    """Return the squared number of patches along one side of the largest image.

    The maximum image size and the patch size are read from the multimodal
    encoder config attached to the tokenizer named in ``ctx.model_config``.
    """
    model_cfg = ctx.model_config
    tok = cached_get_tokenizer(model_cfg.tokenizer,
                               tokenizer_mode=model_cfg.tokenizer_mode)
    encoder_cfg = tok.instruct.mm_encoder.mm_config

    patches_per_side = (encoder_cfg.max_image_size //
                        encoder_cfg.image_patch_size)
    return patches_per_side**2
def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
                           mm_counts: Mapping[str, int]):
    """Build dummy sequence data and image inputs for Pixtral.

    A black square RGB image is encoded with the tokenizer's multimodal
    encoder. The resulting tokens and the image itself are each repeated once
    per image allowed by ``limit_per_prompt`` (default 1).

    Returns:
        A ``(SequenceData, {"image": [PIL images]})`` pair.
    """
    model_cfg = ctx.model_config
    tok = cached_get_tokenizer(model_cfg.tokenizer,
                               tokenizer_mode=model_cfg.tokenizer_mode)
    encoder = tok.instruct.mm_encoder

    image_limit = model_cfg.multimodal_config.limit_per_prompt.get("image", 1)

    # Approximate side length: sqrt(seq_len) patches of image_patch_size px.
    side = int(math.sqrt(seq_len) * encoder.mm_config.image_patch_size)
    blank = Image.new("RGB", (side, side), color=0)

    image_tokens = encoder(ImageChunk(image=blank)).tokens
    repeated_ids = image_limit * array(VLLM_TOKEN_ID_ARRAY_TYPE, image_tokens)

    return SequenceData(repeated_ids), {"image": image_limit * [blank]}
def input_mapper_for_pixtral(ctx: InputContext,
                             data: object) -> MultiModalInputs:
    """Encode image data into the tensors consumed by Pixtral's ``forward``.

    Args:
        ctx: Context of the loaded model; its model config names the
            tokenizer whose multimodal encoder is used.
        data: A single image or a list of images. Each item is wrapped in an
            ``ImageChunk`` and passed through the encoder.

    Returns:
        MultiModalInputs whose ``"images"`` entry is a list with one tensor
        per input image.
    """
    model_cfg = ctx.model_config
    tok = cached_get_tokenizer(model_cfg.tokenizer,
                               tokenizer_mode=model_cfg.tokenizer_mode)

    def encode(item):
        # NOTE(review): device and dtype are hard-coded to CUDA / float16 here.
        encoded = tok.instruct.mm_encoder(ImageChunk(image=item))
        return torch.from_numpy(encoded.image).to(device="cuda",
                                                  dtype=torch.float16)

    items = data if isinstance(data, list) else [data]
    return MultiModalInputs({"images": [encode(item) for item in items]})
def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                image_features: Optional[List[torch.Tensor]],
                                image_id: int) -> torch.Tensor:
    """Overwrite the embeddings at image-token positions with image features.

    Args:
        input_ids: Token ids; its first dimension is the sequence length.
        inputs_embeds: 2-D text embeddings, one row per entry of
            ``input_ids``. Modified in place.
        image_features: 2-D tensor of image features (the code reads
            ``.shape`` and unpacks two dimensions, so a tensor is required
            despite the annotation).
        image_id: Token id that marks an image position in ``input_ids``.

    Returns:
        ``inputs_embeds`` (the same object) with rows at image positions
        replaced by ``image_features``.

    Raises:
        AssertionError: If the feature widths differ, or if the sequence
            length is not the number of non-image tokens plus the number of
            image feature rows.
    """
    text_locations = input_ids != image_id
    image_locations = input_ids == image_id

    seq_len = input_ids.shape[0]

    N_txt = text_locations.sum().item()
    _, D_txt = inputs_embeds.shape
    N_img, D_img = image_features.shape

    # Both fragments need the f-prefix, otherwise "{D_img}" is printed as-is.
    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
                              f"to image features dim {D_img}")
    assert (seq_len == N_txt +
            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
                     f"{(N_txt, N_img, image_locations.sum().item())}")

    inputs_embeds[image_locations, :] = image_features
    return inputs_embeds
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
    """Pixtral: a registered vLLM language model fed by a vision encoder.

    Images are encoded by ``VisionTransformer``, projected to the text hidden
    size by ``VisionLanguageAdapter``, and written into the token embeddings
    at image-token positions before the language model runs.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the language model, vision encoder and adapter from config.

        Args:
            config: Must expose ``text_config`` and ``vision_config``.
            multimodal_config: Stored on the instance.
            cache_config: Forwarded to the language-model constructor.
            quant_config: Forwarded to the language-model constructor.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Keep only the vision-config keys that VisionEncoderArgs declares.
        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for Pixtral.

        When ``kwargs`` carries images, they are encoded and merged into the
        token embeddings, and the language model is called with
        ``inputs_embeds`` (``input_ids`` is then passed as ``None``).
        Otherwise the language model is called with ``input_ids`` only.

        Returns:
            Whatever ``self.language_model.model`` returns.
            ``intermediate_tensors`` is accepted but not forwarded.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_args.image_token_id)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
                               torch.Tensor]] = None
    ) -> Optional[List[torch.Tensor]]:
        """Return the images of the last entry of ``images`` as a list.

        Args:
            images: ``None``, a tensor, or a list. For a tensor or a list only
                the last element along the outer dimension is used.

        Returns:
            ``None`` when no images were given, otherwise a list of the items
            found in ``images[-1]``. Any other input type is returned as is.
        """
        if images is None:
            return None

        if isinstance(images, torch.Tensor):
            # always take last images
            images = [images[-1][i] for i in range(images.size(1))]
        elif isinstance(images, list):
            # always take last images
            # Size the loop from the entry being indexed; using images[0]
            # would over- or under-run when entries hold different counts.
            images = [images[-1][i] for i in range(len(images[-1]))]

        return images

    def _process_image_input(self,
                             image_input: List[torch.Tensor]) -> torch.Tensor:
        """Encode the images and project the result with the adapter."""
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the wrapped language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Route checkpoint weights to the three sub-modules by name prefix.

        Names starting with ``vision_encoder`` go to the vision encoder, names
        starting with ``vision_language_adapter`` go to the adapter, and all
        remaining weights are handed to the language model's own loader.
        """

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
            return is_vision_encoder_weights(
                weight) or is_vision_lang_adapter_weights(weight)

        # Split the single-pass iterable into three independent iterators.
        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
            weights, 3)

        # llm
        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
        self.language_model.load_weights(llm_weights)

        # vision encoder
        vision_encoder_weights = filter(is_vision_encoder_weights,
                                        vision_encoder_weights)
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        for name, loaded_weight in vision_encoder_weights:
            # cut 'vision_encoder.'
            name = '.'.join(name.split(".")[1:])
            param = vision_encoder_dict[name]

            default_weight_loader(param, loaded_weight)

        # adapter
        vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
                                             vision_lang_adapter_weights)
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())
        for name, loaded_weight in vision_lang_adapter_weights:
            # cut 'vision_language_adapter.'
            name = '.'.join(name.split(".")[1:])
            param = vision_lang_adapter_dict[name]
            default_weight_loader(param, loaded_weight)
# Vision encoder
@dataclass
class VisionEncoderArgs:
    """Hyper-parameters consumed by the vision encoder modules in this file."""
    # Width of the encoder: conv output channels and attention/MLP model dim.
    hidden_size: int
    # Input channels of the patch convolution.
    num_channels: int
    # Divided by patch_size to get the max number of patches per side.
    image_size: int
    # Kernel size and stride of the patch convolution.
    patch_size: int
    # Inner width of the feed-forward block.
    intermediate_size: int
    # Number of TransformerBlock layers.
    num_hidden_layers: int
    # Number of attention heads; must divide hidden_size.
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    # Token id whose positions receive image features in the merged embeddings.
    image_token_id: int
def
_reshape_for_broadcast
(
freqs_cis
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
"""
ndim
=
x
.
ndim
assert
ndim
>
1
assert
freqs_cis
.
shape
==
(
x
.
shape
[
1
],
x
.
shape
[
-
1
]),
(
freqs_cis
.
shape
,
(
x
.
shape
[
1
],
x
.
shape
[
-
1
]),
)
shape
=
[
d
if
i
==
1
or
i
==
ndim
-
1
else
1
for
i
,
d
in
enumerate
(
x
.
shape
)
]
return
freqs_cis
.
view
(
*
shape
)
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """Precompute unit-magnitude complex rotations for every (row, col).

    Returns:
        Complex tensor of shape (height, width, dim // 2), meant to be indexed
        by (height, width) position tuples.
    """
    # (dim / 2) frequency bases
    exponents = torch.arange(0, dim, 2).float() / dim
    bases = 1.0 / (theta**exponents)

    rows = torch.arange(height, device=bases.device)
    cols = torch.arange(width, device=bases.device)

    # Even-indexed bases rotate with the row, odd-indexed ones with the column.
    row_angles = torch.outer(rows, bases[::2]).float()
    col_angles = torch.outer(cols, bases[1::2]).float()

    row_grid = row_angles[:, None, :].repeat(1, width, 1)
    col_grid = col_angles[None, :, :].repeat(height, 1, 1)
    angles = torch.cat([row_grid, col_grid], dim=-1)

    magnitude = torch.ones_like(angles)
    return torch.polar(magnitude, angles)
def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the given complex frequencies.

    The last dimension of each input is paired up and viewed as complex
    numbers, multiplied by ``freqs_cis`` (which must be complex64), converted
    back to real values and cast to the input's original dtype.
    """

    def as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair adjacent values of the last dim into (real, imag).
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)

    assert freqs_cis.dtype == torch.complex64
    rotation = _reshape_for_broadcast(freqs_cis, q_complex)

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` without biases."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        model_dim = args.hidden_size
        inner_dim = args.intermediate_size
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to ``x``."""
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(gated)
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings on queries and keys."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert args.hidden_size % args.num_attention_heads == 0

        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        # Four bias-free square projections, created in this order.
        width = args.hidden_size
        self.wq = nn.Linear(width, width, bias=False)
        self.wk = nn.Linear(width, width, bias=False)
        self.wv = nn.Linear(width, width, bias=False)
        self.wo = nn.Linear(width, width, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``x`` (three-dimensional: batch, patches, features).

        ``mask`` is passed to the attention kernel as ``attn_bias``;
        ``freqs_cis`` is applied to queries and keys before attention.
        """
        batch, patches, _ = x.shape
        per_head = (batch, patches, self.n_heads, self.head_dim)

        queries = self.wq(x)
        keys = self.wk(x)
        values = self.wv(x)

        queries = queries.reshape(per_head)
        keys = keys.reshape(per_head)
        values = values.reshape(per_head)

        queries, keys = apply_rotary_emb_vit(queries,
                                             keys,
                                             freqs_cis=freqs_cis)
        attended = memory_efficient_attention(queries,
                                              keys,
                                              values,
                                              attn_bias=mask)
        merged = attended.reshape(batch, patches,
                                  self.n_heads * self.head_dim)
        return self.wo(merged)
class TransformerBlock(nn.Module):
    """Pre-norm block: attention then feed-forward, each with a residual."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``h + feed_forward(norm(h))`` with ``h = x + attn(norm(x))``."""
        attn_out = self.attention.forward(self.attention_norm(x),
                                          mask=mask,
                                          freqs_cis=freqs_cis)
        after_attn = x + attn_out
        ffn_out = self.feed_forward.forward(self.ffn_norm(after_attn))
        return after_attn + ffn_out
class Transformer(nn.Module):
    """Stack of ``num_hidden_layers`` TransformerBlock modules."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        blocks = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.layers = torch.nn.ModuleList()
        self.layers.extend(blocks)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Feed ``x`` through every layer in order, sharing mask and freqs."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden
def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
    """Return (row, col) indices for every patch of every image, concatenated.

    For each tensor the last two dimensions are taken as the grid height and
    width; the grid is enumerated row by row and the per-image index tables
    are concatenated along dimension 0, giving a tensor of shape (N, 2).
    """
    per_image = []
    for embeds in patch_embeds_list:
        rows = torch.arange(embeds.shape[-2])
        cols = torch.arange(embeds.shape[-1])
        row_grid, col_grid = torch.meshgrid(rows, cols, indexing="ij")
        pairs = torch.stack((row_grid, col_grid), dim=-1)
        per_image.append(pairs.reshape(-1, 2))
    return torch.cat(per_image)
class VisionTransformer(nn.Module):
    """Patch-convolution front end followed by a Transformer over all patches.

    All images of a call are flattened into one sequence; a block-diagonal
    mask keeps attention within each image, and 2D rotary frequencies are
    looked up per patch position.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        # Non-overlapping patches: kernel size and stride are both patch_size.
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # Filled lazily by the `freqs_cis` property.
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        """Number of patches along one side of a maximum-size image."""
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        """Device of the module's first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the module's first parameter."""
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Rotary frequency table, built on first use and kept on ``device``."""
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        # Move the cached table if the module's parameters are on another device.
        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)
class VisionLanguageAdapter(nn.Module):
    """Two-layer MLP with GELU mapping vision features to width ``dim``."""

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        vision_width = args.hidden_size
        self.w_in = nn.Linear(vision_width, dim, bias=True)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``w_in``, GELU, then ``w_out``."""
        projected = self.w_in(x)
        activated = self.gelu(projected)
        return self.w_out(activated)
vllm/transformers_utils/config.py
View file @
d394787e
...
...
@@ -70,7 +70,7 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
if
Path
(
model
).
exists
():
return
(
Path
(
model
)
/
config_name
).
is_file
()
return
file_exists
(
model
,
HF_CONFIG_NAME
,
revision
=
revision
,
token
=
token
)
return
file_exists
(
model
,
config_name
,
revision
=
revision
,
token
=
token
)
def
get_config
(
...
...
@@ -205,14 +205,25 @@ def load_params_config(model, revision) -> PretrainedConfig:
config_dict
[
"hidden_act"
]
=
config_dict
.
get
(
"activation"
,
"silu"
)
config_dict
[
"tie_word_embeddings"
]
=
config_dict
.
get
(
"tie_embeddings"
,
False
)
config_dict
[
"max_seq_len"
]
=
config_dict
.
get
(
"max_seq_len"
,
128_000
)
if
config_dict
[
"model_type"
]
==
"transformer"
:
if
"moe"
in
config_dict
:
if
config_dict
.
get
(
"moe"
)
is
not
None
:
config_dict
[
"architectures"
]
=
[
"MixtralForCausalLM"
]
else
:
config_dict
[
"architectures"
]
=
[
"MistralForCausalLM"
]
return
recurse_elems
(
config_dict
)
if
config_dict
.
get
(
"vision_encoder"
)
is
not
None
:
multimodal_config
=
config_dict
.
pop
(
"vision_encoder"
)
config_dict
=
{
"text_config"
:
config_dict
,
"vision_config"
:
multimodal_config
}
config_dict
[
"architectures"
]
=
[
"PixtralForConditionalGeneration"
]
config_dict
[
"model_type"
]
=
"pixtral"
config
=
recurse_elems
(
config_dict
)
return
config
def
get_hf_image_processor_config
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment