Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
"vllm/vscode:/vscode.git/clone" did not exist on "1dc8a70b6d4e8ba4e139f1ddb86a166694f42f21"
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2640 additions
and
313 deletions
+2640
-313
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+111
-38
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+1142
-0
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+409
-0
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+7
-10
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+6
-6
vllm/model_executor/models/phi3.py
vllm/model_executor/models/phi3.py
+17
-0
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+22
-9
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+6
-2
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+6
-10
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+14
-11
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+15
-7
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+48
-20
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+28
-14
vllm/model_executor/models/solar.py
vllm/model_executor/models/solar.py
+580
-0
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+53
-22
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+35
-1
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+58
-0
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+58
-153
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+3
-7
vllm/multimodal/base.py
vllm/multimodal/base.py
+22
-3
No files found.
vllm/model_executor/models/minicpmv.py
View file @
539aa992
...
...
@@ -23,7 +23,6 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import
math
import
re
from
array
import
array
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
)
...
...
@@ -34,11 +33,11 @@ from PIL import Image
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
transformers
import
PretrainedConfig
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -54,21 +53,30 @@ from vllm.model_executor.models.minicpm import MiniCPMModel
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.idefics2_vision_model
import
Idefics2VisionTransformer
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
"llm.lm_head"
:
"lm_head"
,
"llm.model"
:
"llm"
,
}
class
MiniCPMVImageInput
(
TypedDict
):
"""Input mapper input with auxiliary data for computing image bounds."""
image
:
Image
.
Image
# Image bounds token ids in 0-dim scaler tensor.
im_start_id
:
torch
.
Tensor
im_end_id
:
torch
.
Tensor
slice_start_id
:
NotRequired
[
torch
.
Tensor
]
slice_end_id
:
NotRequired
[
torch
.
Tensor
]
class
MiniCPMVImagePixelInputs
(
TypedDict
):
pixel_values
:
List
[
torch
.
Tensor
]
"""
...
...
@@ -93,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
"""
MiniCPMVImageInputs
=
MiniCPMVImagePixelInputs
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
...
...
@@ -239,6 +245,25 @@ class Resampler2_5(BaseResampler):
return
x
def
_build_image_input
(
ctx
:
InputContext
,
image
:
Image
.
Image
)
->
MiniCPMVImageInput
:
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
,
trust_remote_code
=
ctx
.
model_config
.
trust_remote_code
)
if
hasattr
(
tokenizer
,
"slice_start_id"
):
return
MiniCPMVImageInput
(
image
=
image
,
im_start_id
=
torch
.
tensor
(
tokenizer
.
im_start_id
),
im_end_id
=
torch
.
tensor
(
tokenizer
.
im_end_id
),
slice_start_id
=
torch
.
tensor
(
tokenizer
.
slice_start_id
),
slice_end_id
=
torch
.
tensor
(
tokenizer
.
slice_end_id
))
else
:
return
MiniCPMVImageInput
(
image
=
image
,
im_start_id
=
torch
.
tensor
(
tokenizer
.
im_start_id
),
im_end_id
=
torch
.
tensor
(
tokenizer
.
im_end_id
))
def
get_version_by_config
(
config
:
PretrainedConfig
)
->
Tuple
[
int
,
...]:
version_float
=
getattr
(
config
,
"version"
,
None
)
...
...
@@ -259,14 +284,16 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
seq_len
return
SequenceData
(
token_ids
)
return
SequenceData
.
from_token_counts
((
0
,
seq_len
))
def
dummy_image_for_minicpmv
(
hf_config
:
PretrainedConfig
,
num_images
:
int
):
def
dummy_image_for_minicpmv
(
ctx
:
InputContext
,
hf_config
:
PretrainedConfig
,
num_images
:
int
):
width
=
height
=
hf_config
.
image_size
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
image
=
_build_image_input
(
ctx
,
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
))
return
{
"image"
:
[
image
]
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
,
...
...
@@ -275,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
,
num_images
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
,
num_images
)
mm_data
=
dummy_image_for_minicpmv
(
ctx
,
hf_config
,
num_images
)
return
seq_data
,
mm_data
...
...
@@ -286,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
model_config
=
ctx
.
model_config
version
=
get_version_by_config
(
model_config
.
hf_config
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_processor
=
cached_get_image_processor
(
model_config
.
tokenizer
)
def
get_placeholder
(
image_size
:
Tuple
[
int
,
int
],
num_image
:
int
):
...
...
@@ -323,6 +351,10 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt
=
""
.
join
(
new_prompt_chunks
)
new_token_ids
=
tokenizer
.
encode
(
new_prompt
)
multi_modal_data
[
"image"
]
=
[
_build_image_input
(
ctx
,
image
)
for
image
in
images
]
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
...
...
@@ -331,6 +363,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
def
input_mapper_for_minicpmv
(
ctx
:
InputContext
,
data
:
object
):
model_config
=
ctx
.
model_config
image_processor
=
cached_get_image_processor
(
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
)
if
image_processor
is
None
:
raise
RuntimeError
(
"No HuggingFace processor is available "
"to process the image object"
)
if
not
isinstance
(
data
,
list
):
raise
ValueError
(
"Image input must be list of MiniCPMVImageInput, got (%s)"
,
data
)
batch_data
=
image_processor
\
.
preprocess
([
img
[
"image"
]
for
img
in
data
],
return_tensors
=
"pt"
)
\
.
data
if
len
(
data
)
>
0
:
batch_data
[
"im_start_id"
]
=
data
[
0
][
"im_start_id"
]
batch_data
[
"im_end_id"
]
=
data
[
0
][
"im_end_id"
]
if
"slice_start_id"
in
data
[
0
]:
batch_data
[
"slice_start_id"
]
=
data
[
0
][
"slice_start_id"
]
batch_data
[
"slice_end_id"
]
=
data
[
0
][
"slice_end_id"
]
return
MultiModalInputs
(
batch_data
)
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsMultiModal
):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
...
...
@@ -371,7 +429,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
def
get_embedding
(
self
,
input_ids
:
torch
.
Tensor
,
image_inputs
:
Optional
[
MiniCPMVImageInputs
],
image_inputs
:
Optional
[
MiniCPMVImage
Pixel
Inputs
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
vlm_embedding
:
torch
.
Tensor
=
self
.
llm
.
embed_tokens
(
input_ids
)
if
hasattr
(
self
.
config
,
"scale_emb"
):
...
...
@@ -399,14 +457,20 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
return
vlm_embedding
,
vision_hidden_states
def
_get_image_bounds
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tokenizer
=
cached_get_tokenizer
(
self
.
config
.
_name_or_path
,
trust_remote_code
=
True
)
start_cond
=
input_ids
==
tokenizer
.
im_start_id
end_cond
=
input_ids
==
tokenizer
.
im_end_id
if
hasattr
(
tokenizer
,
"slice_start_id"
):
start_cond
|=
(
input_ids
==
tokenizer
.
slice_start_id
)
end_cond
|=
(
input_ids
==
tokenizer
.
slice_end_id
)
def
_get_image_bounds
(
self
,
input_ids
:
torch
.
Tensor
,
im_start_id
:
torch
.
Tensor
,
im_end_id
:
torch
.
Tensor
,
slice_start_id
:
Optional
[
torch
.
Tensor
]
=
None
,
slice_end_id
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# All the images in the batch should share the same special image
# bound token ids.
start_cond
=
input_ids
==
im_start_id
[
0
]
end_cond
=
input_ids
==
im_end_id
[
0
]
if
slice_start_id
is
not
None
:
start_cond
|=
(
input_ids
==
slice_start_id
[
0
])
end_cond
|=
(
input_ids
==
slice_end_id
[
0
])
image_start_tokens
,
=
torch
.
where
(
start_cond
)
image_start_tokens
+=
1
...
...
@@ -425,7 +489,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self
,
input_ids
:
torch
.
Tensor
,
**
kwargs
:
object
,
)
->
Optional
[
MiniCPMVImageInputs
]:
)
->
Optional
[
MiniCPMVImage
Pixel
Inputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
[])
tgt_sizes
=
kwargs
.
pop
(
"tgt_sizes"
,
[])
...
...
@@ -462,8 +526,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
if
len
(
pixel_values_flat
)
==
0
:
return
None
return
MiniCPMVImageInputs
(
image_bounds
=
self
.
_get_image_bounds
(
input_ids
),
im_start_id
=
kwargs
.
pop
(
"im_start_id"
,
None
)
im_end_id
=
kwargs
.
pop
(
"im_end_id"
,
None
)
slice_start_id
=
kwargs
.
pop
(
"slice_start_id"
,
None
)
slice_end_id
=
kwargs
.
pop
(
"slice_end_id"
,
None
)
if
im_start_id
is
None
:
return
None
return
MiniCPMVImagePixelInputs
(
image_bounds
=
self
.
_get_image_bounds
(
input_ids
,
im_start_id
,
im_end_id
,
slice_start_id
,
slice_end_id
),
pixel_values
=
pixel_values_flat
,
tgt_sizes
=
torch
.
stack
(
tgt_sizes_flat
),
)
...
...
@@ -570,8 +643,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
...
...
@@ -660,8 +733,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
res
.
append
(
self
.
resampler
(
vision_embedding
,
tgt_size
))
return
torch
.
vstack
(
res
)
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
return
self
.
get_vision_embedding
(
pixel_values
)
...
...
@@ -719,8 +792,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
vision_embedding
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
return
vision_embedding
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
...
...
@@ -813,8 +886,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
).
last_hidden_state
return
vision_embedding
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
def
get_vision_hidden_states
(
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
...
...
@@ -857,7 +930,7 @@ _SUPPORT_VERSION = {
}
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_minicpmv
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
...
...
@@ -884,7 +957,7 @@ class MiniCPMV(MiniCPMVBaseModel):
version
=
str
(
config
.
version
).
split
(
"."
)
version
=
tuple
([
int
(
x
)
for
x
in
version
])
# Dispatch class based on version
instance_class
=
_SUPPORT_VERSION
.
get
(
version
,
None
)
instance_class
=
_SUPPORT_VERSION
.
get
(
version
)
if
instance_class
is
None
:
raise
ValueError
(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6"
)
...
...
vllm/model_executor/models/mllama.py
0 → 100644
View file @
539aa992
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import
math
from
array
import
array
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
import
transformers.models.mllama.configuration_mllama
as
config_mllama
from
PIL
import
Image
from
torch
import
nn
from
transformers.modeling_outputs
import
(
BaseModelOutput
,
CausalLMOutputWithPast
)
from
transformers.models.mllama.image_processing_mllama
import
(
get_optimal_tiled_canvas
)
import
vllm.distributed.parallel_state
as
ps
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
.clip
import
CLIPMLP
from
.interfaces
import
SupportsMultiModal
from
.llama
import
LlamaDecoderLayer
,
LlamaMLP
logger
=
init_logger
(
__name__
)
MLLAMA_IMAGE_TOKEN_ID
=
128256
MLLAMA_IMAGE_TOKEN
=
"<|image|>"
class
MllamaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: """
"""(batch_size, max_num_image, max_num_chunk, num_channel, height, width)"""
aspect_ratio_ids
:
torch
.
Tensor
"""Shape: `(batch_size, max_num_image)`"""
aspect_ratio_mask
:
torch
.
Tensor
"""Shape: `(batch_size, max_num_image, max_num_tiles)`"""
# TODO: support LlamaImageEmbeddingInputs
def
input_processor_for_mllama
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
# move encoder_prompt to prompt
if
llm_inputs
.
get
(
"prompt"
)
is
None
:
llm_inputs
[
"prompt"
]
=
llm_inputs
[
"encoder_prompt"
]
llm_inputs
[
"prompt_token_ids"
]
=
llm_inputs
[
"encoder_prompt_token_ids"
]
# process multi-modal data
assert
"decoder_multi_modal_data"
not
in
llm_inputs
,
\
"multi-modal data should be put in encoder message of mllama"
multi_modal_data
=
llm_inputs
.
get
(
"encoder_multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
\
or
multi_modal_data
[
"image"
]
is
None
:
# text-only
llm_inputs
[
"encoder_prompt"
]
=
""
llm_inputs
[
"encoder_prompt_token_ids"
]
=
[]
llm_inputs
[
"encoder_multi_modal_data"
]
=
{}
return
llm_inputs
# get num_tiles
if
isinstance
(
multi_modal_data
[
'image'
],
Image
.
Image
):
multi_modal_data
[
'image'
]
=
[
multi_modal_data
[
'image'
]]
hf_config
=
ctx
.
model_config
.
hf_config
num_tiles
=
0
for
image
in
multi_modal_data
[
"image"
]:
width
,
height
=
image
.
size
tile_size
=
hf_config
.
vision_config
.
image_size
canvas_height
,
canvas_width
=
get_optimal_tiled_canvas
(
image_height
=
height
,
image_width
=
width
,
max_image_tiles
=
hf_config
.
vision_config
.
max_num_tiles
,
tile_size
=
tile_size
,
)
num_tiles_height
=
canvas_height
//
tile_size
num_tiles_width
=
canvas_width
//
tile_size
num_tiles
+=
num_tiles_height
*
num_tiles_width
# set encoder prompt based on num_tiles
assert
hf_config
.
vision_config
.
image_size
%
14
==
0
,
\
"chunk size should be multiple of 14"
token_per_chunk
=
(
hf_config
.
vision_config
.
image_size
//
14
)
**
2
+
1
num_tokens
=
num_tiles
*
token_per_chunk
llm_inputs
[
"encoder_prompt"
]
=
MLLAMA_IMAGE_TOKEN
*
num_tokens
llm_inputs
[
"encoder_prompt_token_ids"
]
=
[
MLLAMA_IMAGE_TOKEN_ID
]
*
num_tokens
return
llm_inputs
def
get_max_mllama_image_tokens
(
ctx
:
InputContext
)
->
int
:
hf_config
=
ctx
.
model_config
.
hf_config
token_per_chunk
=
(
hf_config
.
vision_config
.
image_size
//
14
)
**
2
+
1
return
hf_config
.
vision_config
.
max_num_tiles
*
token_per_chunk
def
dummy_decoder_seq_data
(
seq_len
:
int
,
num_images
:
int
):
# <|image|> * num_images + 0 * (seq_len - num_images)
assert
seq_len
>=
num_images
,
\
"seq_len should be greater than or equal to num_images"
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
MLLAMA_IMAGE_TOKEN_ID
])
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
num_images
)
return
SequenceData
(
token_ids
)
def
dummy_encoder_seq_data
(
ctx
:
InputContext
,
num_images
:
int
):
num_tokens
=
get_max_mllama_image_tokens
(
ctx
)
*
num_images
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
MLLAMA_IMAGE_TOKEN_ID
])
*
num_tokens
return
SequenceData
(
token_ids
)
def
dummy_image
(
num_images
:
int
,
):
width
=
height
=
1024
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_decoder_data_for_mllama
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
return
dummy_decoder_seq_data
(
seq_len
,
num_images
),
None
def
dummy_encoder_data_for_mllama
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
return
dummy_encoder_seq_data
(
ctx
,
num_images
),
dummy_image
(
num_images
)
def
_prepare_aspect_ratio_attention_mask
(
aspect_ratio_mask
:
torch
.
Tensor
,
num_patches
:
int
,
target_length
:
int
,
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
# Expand aspect ratio mask to target_length
batch_size
,
max_num_tiles
=
aspect_ratio_mask
.
shape
attention_mask
=
aspect_ratio_mask
.
view
(
batch_size
,
max_num_tiles
,
1
,
1
).
to
(
dtype
)
attention_mask
=
attention_mask
.
repeat
(
1
,
1
,
target_length
,
1
)
# Mask padding patches
pad_patches
=
target_length
-
num_patches
attention_mask
[:,
:,
-
pad_patches
:]
=
0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask
=
1
-
attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
attention_mask
=
attention_mask
.
reshape
(
batch_size
,
max_num_tiles
*
target_length
,
1
)
attention_mask
=
attention_mask
@
attention_mask
.
transpose
(
-
1
,
-
2
)
*
torch
.
finfo
(
dtype
).
min
attention_mask
=
attention_mask
.
unsqueeze
(
1
)
return
attention_mask
class
ColumnParallelConv2dPatch
(
torch
.
nn
.
Module
):
"""Conv2D Patching layer with model parallelism.
Column parallel over unfolded input.
Arguments:
in_channels: Input channels.
out_channels: Output channels.
kernel_size: Size of convolution kernel.
stride (default 1): Stride for convolution.
bias (default False): Use bias in Conv2d.
Input: (bsz, in_channels, width, height)
Output: (bsz, num_tokens, out_channels)
"""
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
kernel_size
:
Union
[
int
,
Tuple
[
int
,
int
]],
stride
:
Union
[
int
,
Tuple
[
int
,
int
]],
bias
:
bool
=
False
,
)
->
None
:
super
().
__init__
()
if
isinstance
(
kernel_size
,
int
):
kernel_size
=
(
kernel_size
,
kernel_size
)
self
.
_unfold
=
torch
.
nn
.
Unfold
(
kernel_size
=
kernel_size
,
stride
=
stride
)
self
.
_linear
=
ColumnParallelLinear
(
in_channels
*
kernel_size
[
0
]
*
kernel_size
[
1
],
out_channels
,
bias
=
bias
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
self
.
_unfold
(
x
)
x
=
x
.
permute
(
0
,
2
,
1
)
x
,
_
=
self
.
_linear
(
x
)
return
x
class
MllamaPrecomputedAspectRatioEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
,
is_gated
:
bool
=
True
):
super
().
__init__
()
self
.
max_num_tiles
=
config
.
max_num_tiles
self
.
hidden_size
=
config
.
hidden_size
self
.
max_aspect_ratio_id
=
config
.
max_aspect_ratio_id
self
.
is_gated
=
is_gated
self
.
embedding
=
nn
.
Embedding
(
self
.
max_aspect_ratio_id
+
1
,
self
.
max_num_tiles
*
self
.
hidden_size
)
if
is_gated
:
self
.
gate
=
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
hidden_state
:
torch
.
Tensor
,
aspect_ratio_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
embeddings
=
self
.
embedding
(
aspect_ratio_ids
)
embeddings
=
embeddings
.
reshape
(
-
1
,
self
.
max_num_tiles
,
1
,
self
.
hidden_size
)
if
self
.
is_gated
:
embeddings
=
embeddings
*
self
.
gate
.
tanh
()
hidden_state
=
hidden_state
+
embeddings
return
hidden_state
class
MllamaPrecomputedPositionEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
):
super
().
__init__
()
self
.
max_num_tiles
=
config
.
max_num_tiles
self
.
max_aspect_ratio_id
=
config
.
max_aspect_ratio_id
self
.
num_patches
=
(
config
.
image_size
//
config
.
patch_size
)
**
2
+
1
self
.
hidden_size
=
config
.
hidden_size
self
.
scale
=
config
.
hidden_size
**-
0.5
self
.
gate
=
nn
.
Parameter
(
torch
.
zeros
(
1
))
# position embedding
position_embedding
=
torch
.
randn
(
self
.
num_patches
,
self
.
hidden_size
)
self
.
embedding
=
nn
.
Parameter
(
self
.
scale
*
position_embedding
)
# tile position embedding
self
.
tile_embedding
=
nn
.
Embedding
(
self
.
max_aspect_ratio_id
+
1
,
self
.
max_num_tiles
*
self
.
num_patches
*
self
.
hidden_size
)
def
forward
(
self
,
hidden_state
:
torch
.
Tensor
,
aspect_ratio_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# position embeddings
gated_position_embedding
=
(
1
-
self
.
gate
.
tanh
())
*
self
.
embedding
hidden_state
=
hidden_state
+
gated_position_embedding
.
view
(
1
,
1
,
self
.
num_patches
,
self
.
hidden_size
)
# precomputed tile position embeddings
tile_position_embedding
=
self
.
tile_embedding
(
aspect_ratio_ids
)
batch_size
=
hidden_state
.
shape
[
0
]
tile_position_embedding
=
tile_position_embedding
.
reshape
(
batch_size
,
self
.
max_num_tiles
,
self
.
num_patches
,
self
.
hidden_size
)
gated_tile_position_embedding
=
self
.
gate
.
tanh
(
)
*
tile_position_embedding
hidden_state
=
hidden_state
+
gated_tile_position_embedding
return
hidden_state
# TODO: support other attention backends for attention in vision model
class
MllamaVisionSdpaAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
):
super
().
__init__
()
model_parallel_size
=
get_tensor_model_parallel_world_size
()
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
attention_heads
self
.
head_dim
=
config
.
hidden_size
//
config
.
attention_heads
self
.
num_local_heads
=
self
.
num_heads
//
model_parallel_size
self
.
q_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
self
.
num_heads
,
bias
=
False
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
embed_dim
,
bias
=
False
,
input_is_parallel
=
True
,
)
def
forward
(
self
,
hidden_state
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_state
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
=
q
.
view
(
q
.
shape
[
0
],
q
.
shape
[
1
],
self
.
num_local_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k
=
k
.
view
(
k
.
shape
[
0
],
k
.
shape
[
1
],
self
.
num_local_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
v
=
v
.
view
(
v
.
shape
[
0
],
v
.
shape
[
1
],
self
.
num_local_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
# TODO: remove padding in image encoder
attn_output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attention_mask
,
dropout_p
=
0.0
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
attn_output
.
shape
[
0
],
attn_output
.
shape
[
1
],
-
1
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
MllamaVisionEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
,
is_gated
:
bool
=
False
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_attention_heads
=
config
.
attention_heads
self
.
is_gated
=
is_gated
self
.
intermediate_size
=
config
.
intermediate_size
self
.
self_attn
=
MllamaVisionSdpaAttention
(
config
)
self
.
mlp
=
CLIPMLP
(
config
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
self
.
hidden_size
,
eps
=
config
.
norm_eps
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
self
.
hidden_size
,
eps
=
config
.
norm_eps
)
# there used to be an if else here, no code path
if
is_gated
:
self
.
gate_attn
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
math
.
pi
/
4
)
self
.
gate_ffn
=
nn
.
Parameter
(
torch
.
ones
(
1
)
*
math
.
pi
/
4
)
def
forward
(
self
,
hidden_state
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
):
# Self Attention
residual
=
hidden_state
hidden_state
=
self
.
input_layernorm
(
hidden_state
)
hidden_state
=
self
.
self_attn
(
hidden_state
,
attention_mask
=
attention_mask
)
gate_attn
=
1
if
not
self
.
is_gated
else
self
.
gate_attn
.
tanh
()
hidden_state
=
residual
+
gate_attn
*
hidden_state
# Feed forward
residual
=
hidden_state
hidden_state
=
self
.
post_attention_layernorm
(
hidden_state
)
hidden_state
=
self
.
mlp
(
hidden_state
)
gate_ffn
=
1
if
not
self
.
is_gated
else
self
.
gate_ffn
.
tanh
()
hidden_state
=
residual
+
gate_ffn
*
hidden_state
return
hidden_state
class
MllamaVisionEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
,
num_layers
=
32
,
is_gated
=
False
,
output_hidden_states
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
MllamaVisionEncoderLayer
(
config
,
is_gated
)
for
_
in
range
(
num_layers
)
])
self
.
output_hidden_states
=
output_hidden_states
or
[]
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutput
]:
encoder_states
=
()
for
i
,
encoder_layer
in
enumerate
(
self
.
layers
):
if
i
in
self
.
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,
)
hidden_states
=
encoder_layer
(
hidden_states
,
attention_mask
,
)
if
len
(
self
.
layers
)
-
1
in
self
.
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,
)
return
hidden_states
,
encoder_states
class
MllamaVisionModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
):
super
().
__init__
()
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
max_num_tiles
=
config
.
max_num_tiles
self
.
hidden_size
=
config
.
hidden_size
self
.
in_channels
=
config
.
num_channels
self
.
intermediate_layers_indices
=
config
.
intermediate_layers_indices
self
.
num_patches
=
(
self
.
image_size
//
self
.
patch_size
)
**
2
+
1
self
.
scale
=
config
.
hidden_size
**-
0.5
self
.
patch_embedding
=
ColumnParallelConv2dPatch
(
in_channels
=
config
.
num_channels
,
out_channels
=
self
.
hidden_size
,
kernel_size
=
self
.
patch_size
,
stride
=
self
.
patch_size
,
bias
=
False
,
)
self
.
class_embedding
=
nn
.
Parameter
(
self
.
scale
*
torch
.
randn
(
self
.
hidden_size
))
self
.
gated_positional_embedding
=
MllamaPrecomputedPositionEmbedding
(
config
)
self
.
pre_tile_positional_embedding
=
\
MllamaPrecomputedAspectRatioEmbedding
(
config
,
is_gated
=
True
)
self
.
post_tile_positional_embedding
=
\
MllamaPrecomputedAspectRatioEmbedding
(
config
,
is_gated
=
True
)
# layer norms
self
.
layernorm_pre
=
nn
.
LayerNorm
(
self
.
hidden_size
)
self
.
layernorm_post
=
nn
.
LayerNorm
(
self
.
hidden_size
)
# encoders
self
.
transformer
=
MllamaVisionEncoder
(
config
,
config
.
num_hidden_layers
,
is_gated
=
False
,
output_hidden_states
=
config
.
intermediate_layers_indices
)
self
.
global_transformer
=
MllamaVisionEncoder
(
config
,
config
.
num_global_layers
,
is_gated
=
True
)
def
apply_class_embedding
(
self
,
hidden_state
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
_
,
hidden_size
=
hidden_state
.
shape
class_embedding
=
self
.
class_embedding
.
expand
(
batch_size
,
1
,
hidden_size
)
hidden_state
=
torch
.
cat
([
class_embedding
,
hidden_state
],
dim
=
1
)
return
hidden_state
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
aspect_ratio_ids
:
torch
.
Tensor
,
aspect_ratio_mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
num_concurrent_media
,
num_tiles
,
num_channels
,
\
height
,
width
=
pixel_values
.
shape
pixel_values
=
pixel_values
.
reshape
(
batch_size
*
num_concurrent_media
*
num_tiles
,
num_channels
,
height
,
width
)
aspect_ratio_ids
=
aspect_ratio_ids
.
reshape
(
batch_size
*
num_concurrent_media
,
-
1
)
# patch embedding
patch_embeds
=
self
.
patch_embedding
(
pixel_values
.
to
(
self
.
layernorm_pre
.
weight
.
dtype
))
hidden_state
=
patch_embeds
hidden_state
=
ps
.
get_tp_group
().
all_gather
(
hidden_state
)
# tile embeddings
_
,
num_patches
,
dim
=
hidden_state
.
shape
hidden_state
=
hidden_state
.
reshape
(
batch_size
*
num_concurrent_media
,
num_tiles
,
-
1
,
dim
)
hidden_state
=
self
.
pre_tile_positional_embedding
(
hidden_state
,
aspect_ratio_ids
)
# apply cls token
hidden_state
=
hidden_state
.
reshape
(
batch_size
*
num_concurrent_media
*
num_tiles
,
num_patches
,
dim
)
hidden_state
=
self
.
apply_class_embedding
(
hidden_state
)
num_patches
+=
1
# apply position embeddings
hidden_state
=
hidden_state
.
reshape
(
batch_size
*
num_concurrent_media
,
num_tiles
,
num_patches
,
dim
)
hidden_state
=
self
.
gated_positional_embedding
(
hidden_state
,
aspect_ratio_ids
)
# apply encoder
hidden_state
=
self
.
layernorm_pre
(
hidden_state
)
# Compute the number of tokens to pad
num_padding_patches
=
(
8
-
(
hidden_state
.
shape
[
-
2
]
%
8
))
%
8
# Compute padding tuple for pad function
padding
=
(
0
,
0
,
0
,
num_padding_patches
)
# (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
# Pad the tensor
hidden_state
=
F
.
pad
(
hidden_state
,
padding
,
mode
=
"constant"
,
value
=
0
)
slice_index
=
-
num_padding_patches
if
num_padding_patches
>
0
else
None
attention_mask
=
aspect_ratio_mask
.
reshape
(
batch_size
*
num_concurrent_media
,
-
1
)
attention_mask
=
_prepare_aspect_ratio_attention_mask
(
aspect_ratio_mask
=
attention_mask
,
num_patches
=
self
.
num_patches
,
target_length
=
hidden_state
.
shape
[
2
],
dtype
=
self
.
layernorm_pre
.
weight
.
dtype
,
)
hidden_state
=
hidden_state
.
view
(
batch_size
*
num_concurrent_media
,
-
1
,
dim
)
output
=
self
.
transformer
(
hidden_state
,
attention_mask
=
attention_mask
,
)
hidden_state
,
intermediate_hidden_states
=
output
[
0
],
output
[
1
]
intermediate_hidden_states
=
torch
.
stack
(
intermediate_hidden_states
,
dim
=-
1
)
# apply global encoder
hidden_state
=
self
.
layernorm_post
(
hidden_state
)
hidden_state
=
hidden_state
.
reshape
(
batch_size
*
num_concurrent_media
,
num_tiles
,
num_patches
+
num_padding_patches
,
dim
)
hidden_state
=
self
.
post_tile_positional_embedding
(
hidden_state
,
aspect_ratio_ids
)
hidden_state
=
hidden_state
.
reshape
(
batch_size
*
num_concurrent_media
,
num_tiles
*
(
num_patches
+
num_padding_patches
),
dim
)
hidden_state
=
self
.
global_transformer
(
hidden_state
,
attention_mask
=
attention_mask
)[
0
]
hidden_state
=
hidden_state
.
reshape
(
batch_size
*
num_concurrent_media
,
num_tiles
,
num_patches
+
num_padding_patches
,
dim
)
hidden_state
=
hidden_state
[:,
:,
:
slice_index
]
# adding intermediate layer outputs
hidden_state
=
hidden_state
.
reshape
(
batch_size
,
num_concurrent_media
,
num_tiles
,
num_patches
,
dim
)
intermediate_hidden_states
=
intermediate_hidden_states
.
reshape
(
batch_size
*
num_concurrent_media
,
num_tiles
,
num_patches
+
num_padding_patches
,
-
1
)
intermediate_hidden_states
=
intermediate_hidden_states
[:,
:,
:
slice_index
]
intermediate_hidden_states
=
intermediate_hidden_states
.
reshape
(
batch_size
,
num_concurrent_media
,
num_tiles
,
num_patches
,
-
1
)
hidden_state
=
torch
.
cat
([
hidden_state
,
intermediate_hidden_states
],
dim
=-
1
)
return
hidden_state
class
MllamaTextRMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
eps
=
1e-6
):
"""
MllamaTextRMSNorm is equivalent to T5LayerNorm
"""
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
variance_epsilon
=
eps
def
forward
(
self
,
hidden_states
):
input_dtype
=
hidden_states
.
dtype
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
hidden_states
=
hidden_states
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
return
self
.
weight
*
hidden_states
.
to
(
input_dtype
)
def
extra_repr
(
self
):
return
f
"
{
tuple
(
self
.
weight
.
shape
)
}
, eps=
{
self
.
variance_epsilon
}
"
class
MllamaTextCrossAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
:
Optional
[
config_mllama
.
MllamaTextConfig
]
=
None
,
layer_idx
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
model_parallel_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads
=
self
.
config
.
num_attention_heads
self
.
num_local_heads
=
self
.
num_heads
//
self
.
model_parallel_size
self
.
num_key_value_heads
=
self
.
config
.
num_key_value_heads
self
.
num_local_key_value_heads
=
\
self
.
num_key_value_heads
//
self
.
model_parallel_size
self
.
dropout
=
config
.
dropout
self
.
hidden_size
=
config
.
hidden_size
self
.
head_dim
=
config
.
hidden_size
//
self
.
num_heads
self
.
layer_idx
=
layer_idx
self
.
num_key_value_groups
=
self
.
num_heads
//
self
.
num_key_value_heads
self
.
q_local_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
kv_local_size
=
self
.
num_local_key_value_heads
*
self
.
head_dim
# TODO: change to Q/KV separate linear after #7448 is merged
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
head_dim
,
self
.
num_heads
,
self
.
num_key_value_heads
,
bias
=
False
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
)
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# use huggingface's instead
self
.
q_norm
=
MllamaTextRMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
k_norm
=
MllamaTextRMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Attention
(
self
.
num_local_heads
,
self
.
head_dim
,
self
.
scaling
,
self
.
num_local_key_value_heads
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
],
cross_attention_states
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv_dec
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
_
,
_
=
qkv_dec
.
split
(
[
self
.
q_local_size
,
self
.
kv_local_size
,
self
.
kv_local_size
],
dim
=-
1
)
if
cross_attention_states
is
None
:
k
=
None
v
=
None
else
:
qkv_enc
,
_
=
self
.
qkv_proj
(
cross_attention_states
)
_
,
k
,
v
=
qkv_enc
.
split
(
[
self
.
q_local_size
,
self
.
kv_local_size
,
self
.
kv_local_size
],
dim
=-
1
)
k
=
k
.
view
(
-
1
,
self
.
num_local_key_value_heads
,
self
.
head_dim
)
v
=
v
.
view
(
-
1
,
self
.
num_local_key_value_heads
,
self
.
head_dim
)
k
=
self
.
k_norm
(
k
)
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
head_dim
)
q
=
self
.
q_norm
(
q
)
output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
out
,
_
=
self
.
o_proj
(
output
)
return
out
class
MllamaCrossAttentionDecoderLayer
(
torch
.
nn
.
Module
):
"""Cross-attention transformer block with tanh-gated attention
and feedforward."""
def
__init__
(
self
,
config
:
config_mllama
.
MllamaTextConfig
,
layer_idx
:
int
,
quant_config
:
Optional
[
QuantizationConfig
])
\
->
None
:
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
cross_attn
=
MllamaTextCrossAttention
(
config
=
config
,
layer_idx
=
layer_idx
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
cross_attn_attn_gate
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
))
self
.
mlp
=
LlamaMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
cross_attn_mlp_gate
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
cross_attention_states
:
torch
.
Tensor
,
cross_attention_mask
:
torch
.
Tensor
,
full_text_row_masked_out_mask
:
torch
.
Tensor
,
kv_cache
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
cross_attn
(
hidden_states
=
hidden_states
,
attention_mask
=
cross_attention_mask
,
cross_attention_states
=
cross_attention_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
full_text_row_masked_out_mask
*
hidden_states
hidden_states
=
residual
+
self
.
cross_attn_attn_gate
.
tanh
(
)
*
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
full_text_row_masked_out_mask
*
hidden_states
hidden_states
=
residual
+
self
.
cross_attn_mlp_gate
.
tanh
(
)
*
hidden_states
return
hidden_states
class
MllamaTextModel
(
nn
.
Module
):
config_class
=
config_mllama
.
MllamaTextConfig
base_model_prefix
=
"model"
def
__init__
(
self
,
config
:
config_mllama
.
MllamaTextConfig
,
cache_config
:
Optional
[
CacheConfig
],
quant_config
:
Optional
[
QuantizationConfig
]):
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
+
8
,
config
.
hidden_size
)
self
.
cross_attention_layers
=
config
.
cross_attention_layers
layers
=
[]
for
layer_idx
in
range
(
config
.
num_hidden_layers
):
if
layer_idx
in
self
.
cross_attention_layers
:
layers
.
append
(
MllamaCrossAttentionDecoderLayer
(
config
,
layer_idx
,
quant_config
=
quant_config
))
else
:
# TODO: force LlamaDecoderLayer to config.attention_bias=False
layers
.
append
(
LlamaDecoderLayer
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
))
self
.
layers
=
nn
.
ModuleList
(
layers
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
Optional
[
torch
.
LongTensor
],
cross_attention_states
:
Optional
[
torch
.
LongTensor
],
cross_attention_mask
:
Optional
[
torch
.
LongTensor
],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
skip_cross_attention
:
bool
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
inputs_embeds
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
):
if
isinstance
(
decoder_layer
,
MllamaCrossAttentionDecoderLayer
):
if
not
skip_cross_attention
:
hidden_states
=
decoder_layer
(
hidden_states
=
hidden_states
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
)
elif
isinstance
(
decoder_layer
,
LlamaDecoderLayer
):
hidden_states
,
residual
=
decoder_layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
residual
=
None
,
)
hidden_states
=
hidden_states
+
residual
else
:
raise
ValueError
(
f
"Unknown decoder layer type
{
type
(
decoder_layer
)
}
"
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
MllamaForCausalLM
(
nn
.
Module
):
config_class
=
config_mllama
.
MllamaTextConfig
base_model_prefix
=
"language_model"
_no_split_modules
=
[
"MllamaCrossAttentionDecoderLayer"
,
"MllamaSelfAttentionDecoderLayer"
]
def
__init__
(
self
,
config
:
config_mllama
.
MllamaTextConfig
,
cache_config
:
Optional
[
CacheConfig
],
quant_config
:
Optional
[
QuantizationConfig
]):
super
().
__init__
()
self
.
vocab_size
=
config
.
vocab_size
self
.
model
=
MllamaTextModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
Optional
[
torch
.
LongTensor
],
cross_attention_states
:
Optional
[
torch
.
LongTensor
],
cross_attention_mask
:
Optional
[
torch
.
LongTensor
],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
skip_cross_attention
:
bool
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
skip_cross_attention
=
skip_cross_attention
,
)
return
hidden_states
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_mllama_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_decoder_data_for_mllama
)
@
INPUT_REGISTRY
.
register_dummy_encoder_data
(
dummy_encoder_data_for_mllama
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_mllama
)
class
MllamaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
vocab_size
=
config
.
text_config
.
vocab_size
self
.
hidden_size
=
config
.
text_config
.
hidden_size
self
.
max_num_tiles
=
config
.
vision_config
.
max_num_tiles
self
.
vision_output_dim
=
config
.
vision_config
.
vision_output_dim
self
.
pad_token_id
=
\
config
.
pad_token_id
if
config
.
pad_token_id
is
not
None
else
-
1
self
.
image_size
=
config
.
vision_config
.
image_size
self
.
vision_model
=
MllamaVisionModel
(
config
.
vision_config
)
self
.
language_model
=
MllamaForCausalLM
(
config
.
text_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
)
self
.
multi_modal_projector
=
nn
.
Linear
(
config
.
vision_config
.
vision_output_dim
,
config
.
text_config
.
hidden_size
,
bias
=
True
,
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
output_hidden_states
,
config
.
text_config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
language_model
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
):
# tensor with the same shape will be batched together by
# MultiModalInputs.batch, so pixel_values here can be:
# - List[List[torch.Tensor]]:
# with shape (num_tiles, 3, image_res, image_res)
# - List[torch.Tensor]:
# with shape (num_image, num_tiles, 3, image_res, image_res)
# - torch.Tensor:
# with shape (bs, num_image, num_tiles, 3, image_res, image_res)
pixel_values
:
Optional
[
Union
[
List
[
List
[
torch
.
Tensor
]],
List
[
torch
.
Tensor
],
torch
.
Tensor
]]
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
:
Optional
[
Union
[
List
[
List
[
torch
.
Tensor
]],
List
[
torch
.
Tensor
],
torch
.
Tensor
]]
=
kwargs
.
pop
(
"image_embeds"
,
None
)
aspect_ratio_ids
:
Optional
[
Union
[
List
[
List
[
torch
.
Tensor
]],
List
[
torch
.
Tensor
],
torch
.
Tensor
]]
=
kwargs
.
pop
(
"aspect_ratio_ids"
,
None
)
aspect_ratio_mask
:
Optional
[
Union
[
List
[
List
[
torch
.
Tensor
]],
List
[
torch
.
Tensor
],
torch
.
Tensor
]]
=
kwargs
.
pop
(
"aspect_ratio_mask"
,
None
)
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
pixel_values
is
not
None
and
image_embeds
is
not
None
:
raise
ValueError
(
"Both pixel values and image embeds are provided."
)
if
pixel_values
is
not
None
:
assert
aspect_ratio_ids
is
not
None
assert
aspect_ratio_mask
is
not
None
max_num_images
=
max
([
len
(
x
[
0
])
for
x
in
pixel_values
])
if
max_num_images
==
0
:
raise
ValueError
(
"No images provided."
)
max_num_tiles
=
max
(
max
([
len
(
x
)
for
x
in
y
[
0
]])
for
y
in
pixel_values
)
device
=
self
.
multi_modal_projector
.
weight
.
device
bsz
=
len
(
pixel_values
)
out_num_tiles
=
[]
out_images
=
torch
.
zeros
(
bsz
,
max_num_images
,
max_num_tiles
,
3
,
self
.
image_size
,
self
.
image_size
,
dtype
=
torch
.
float32
,
device
=
device
,
)
out_ar_ids
=
torch
.
ones
(
bsz
,
max_num_images
,
dtype
=
torch
.
int64
,
device
=
device
)
out_ar_mask
=
torch
.
zeros
(
bsz
,
max_num_images
,
max_num_tiles
,
dtype
=
torch
.
int64
,
device
=
device
)
for
b
in
range
(
len
(
pixel_values
)):
_num_tiles
=
[]
for
i
in
range
(
len
(
pixel_values
[
b
][
0
])):
img
=
pixel_values
[
b
][
0
][
i
]
out_images
[
b
,
i
,
:
img
.
shape
[
0
]]
=
img
out_ar_ids
[
b
,
i
]
=
aspect_ratio_ids
[
b
][
0
][
i
]
out_ar_mask
[
b
,
i
]
=
aspect_ratio_mask
[
b
][
0
][
i
]
_num_tiles
.
append
(
img
.
shape
[
0
])
out_num_tiles
.
append
(
_num_tiles
)
return
MllamaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
out_images
,
aspect_ratio_ids
=
out_ar_ids
,
aspect_ratio_mask
=
out_ar_mask
,
)
if
image_embeds
is
not
None
:
raise
NotImplementedError
raise
AssertionError
(
"This line should be unreachable."
)
def
flat_encoder_result
(
self
,
cross_attention_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
):
cross_attention_states_flat
=
torch
.
zeros
(
sum
(
attn_metadata
.
encoder_seq_lens
),
cross_attention_states
.
shape
[
-
1
],
device
=
cross_attention_states
.
device
,
dtype
=
cross_attention_states
.
dtype
)
start_pos
=
0
for
seq_len
,
vision_token_in_batch
in
zip
(
attn_metadata
.
encoder_seq_lens
,
cross_attention_states
):
end_pos
=
start_pos
+
seq_len
cross_attention_states_flat
[
start_pos
:
end_pos
]
=
vision_token_in_batch
[:
seq_len
]
start_pos
=
end_pos
cross_attention_states
=
cross_attention_states_flat
full_text_row_masked_out_mask
=
torch
.
ones
(
(
attn_metadata
.
num_prefill_tokens
,
1
),
dtype
=
torch
.
bool
)
start_pos
=
0
for
seq_len
,
encoder_seq_len
in
zip
(
attn_metadata
.
seq_lens_tensor
.
cpu
(),
attn_metadata
.
encoder_seq_lens
):
if
encoder_seq_len
==
0
:
full_text_row_masked_out_mask
[
start_pos
:
start_pos
+
seq_len
]
=
False
start_pos
+=
seq_len
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
.
to
(
cross_attention_states
.
device
)
return
cross_attention_states
,
full_text_row_masked_out_mask
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
object
,
)
->
Union
[
Tuple
,
CausalLMOutputWithPast
]:
if
attn_metadata
.
num_prefill_tokens
>
0
and
\
attn_metadata
.
num_decode_tokens
>
0
:
raise
ValueError
(
"Chunk prefill not supported"
)
image_inputs
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_inputs
is
None
:
cross_attention_mask
=
None
full_text_row_masked_out_mask
=
(
attn_metadata
.
encoder_seq_lens_tensor
!=
0
).
reshape
(
-
1
,
1
).
to
(
input_ids
.
device
)
cross_attention_states
=
None
skip_cross_attention
=
max
(
attn_metadata
.
encoder_seq_lens
)
==
0
else
:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values
=
image_inputs
[
'data'
]
aspect_ratio_ids
=
image_inputs
[
'aspect_ratio_ids'
]
aspect_ratio_mask
=
image_inputs
[
'aspect_ratio_mask'
]
cross_attention_states
=
self
.
vision_model
(
pixel_values
,
aspect_ratio_ids
,
aspect_ratio_mask
)
cross_attention_states
=
self
.
multi_modal_projector
(
cross_attention_states
)
bsz
,
_
,
_
,
_
,
image_token_dim
=
tuple
(
cross_attention_states
.
shape
)
cross_attention_states
=
cross_attention_states
.
view
(
bsz
,
-
1
,
image_token_dim
)
cross_attention_states
,
full_text_row_masked_out_mask
=
\
self
.
flat_encoder_result
(
cross_attention_states
,
attn_metadata
)
skip_cross_attention
=
False
# TODO: support multi-image by this mask
cross_attention_mask
=
None
outputs
=
self
.
language_model
(
input_ids
=
input_ids
,
positions
=
positions
,
cross_attention_states
=
cross_attention_states
,
cross_attention_mask
=
cross_attention_mask
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
skip_cross_attention
=
skip_cross_attention
,
)
return
outputs
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
updated_params
=
set
()
for
name
,
loaded_weight
in
weights
:
if
'patch_embedding.weight'
in
name
:
name
=
name
.
replace
(
'patch_embedding.weight'
,
'patch_embedding._linear.weight'
)
loaded_weight
=
loaded_weight
.
view
(
loaded_weight
.
shape
[
0
],
-
1
)
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
updated_params
.
add
(
name
)
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
.
pop
(
name
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/olmoe.py
0 → 100644
View file @
539aa992
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
print_warning_once
class
OlmoeMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for Olmoe that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def
__init__
(
self
,
num_experts
:
int
,
top_k
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
# Gate always runs at half / full precision for now.
self
.
gate
=
ReplicatedLinear
(
hidden_size
,
num_experts
,
bias
=
False
,
quant_config
=
None
)
self
.
experts
=
FusedMoE
(
num_experts
=
num_experts
,
top_k
=
top_k
,
hidden_size
=
hidden_size
,
intermediate_size
=
intermediate_size
,
reduce_results
=
True
,
renormalize
=
False
,
quant_config
=
quant_config
,
tp_size
=
tp_size
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
hidden_dim
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
return
final_hidden_states
.
view
(
orig_shape
)
class
OlmoeAttention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
4096
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
quant_config
=
quant_config
,
)
self
.
q_norm
=
RMSNorm
(
hidden_size
,
eps
=
1e-5
)
self
.
k_norm
=
RMSNorm
(
hidden_size
,
eps
=
1e-5
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
True
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
q_norm
(
q
.
contiguous
()),
self
.
k_norm
(
k
.
contiguous
())
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
OlmoeDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
4096
)
self
.
self_attn
=
OlmoeAttention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
)
self
.
mlp
=
OlmoeMoE
(
num_experts
=
config
.
num_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
class
OlmoeModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
OlmoeDecoderLayer
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
OlmoeForCausalLM
(
nn
.
Module
):
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
OlmoeModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
if
weight_name
not
in
name
:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if
"mlp.experts"
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
remapped_kv_scale_name
=
name
.
replace
(
".kv_scale"
,
".attn.kv_scale"
)
if
remapped_kv_scale_name
not
in
params_dict
:
print_warning_once
(
"Found kv scale in the checkpoint "
f
"(e.g.
{
name
}
), but not found the expected "
f
"name in the model "
f
"(e.g.
{
remapped_kv_scale_name
}
). "
"kv-scale is not loaded."
)
continue
else
:
name
=
remapped_kv_scale_name
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/paligemma.py
View file @
539aa992
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
filter
_weights
,
merge_multimodal_embeddings
from
.utils
import
group
_weights
_with_prefix
,
merge_multimodal_embeddings
logger
=
init_logger
(
__name__
)
...
...
@@ -153,7 +152,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
config
.
text_config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -286,21 +286,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
vit_
weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
weights
_group
=
group_weights_with_prefix
(
weights
)
# load vision tower
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_
weights
:
for
name
,
loaded_weight
in
weights
_group
[
"multi_modal_projector"
]
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/persimmon.py
View file @
539aa992
...
...
@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
text_
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
text_config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
PersimmonDecoderLayer
(
config
,
cache_config
=
cache_config
,
...
...
@@ -257,14 +257,14 @@ class PersimmonForCausalLM(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
text_
config
.
vocab_size
self
.
model
=
PersimmonModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
text_config
.
vocab_size
,
config
.
hidden_size
,
bias
=
False
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
...
...
vllm/model_executor/models/phi3.py
0 → 100644
View file @
539aa992
# coding=utf-8
# Adapted from llama.py
"""Inference-only Phi3 model code inherit from Llama.py"""
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
class
Phi3ForCausalLM
(
LlamaForCausalLM
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
,
],
"gate_up_proj"
:
[
"gate_up_proj"
,
],
}
vllm/model_executor/models/phi3v.py
View file @
539aa992
...
...
@@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def
_calc_hd_transform_size
(
*
,
width
:
int
,
height
:
int
,
hd_num
:
int
=
16
):
def
_calc_hd_transform_size
(
*
,
width
:
int
,
height
:
int
,
hd_num
:
int
):
transposed
=
False
if
width
<
height
:
width
,
height
=
height
,
width
...
...
@@ -337,8 +337,10 @@ def get_phi3v_image_feature_size(
*
,
input_height
:
int
,
input_width
:
int
,
num_crops
:
int
,
)
->
int
:
num_crops
=
hf_config
.
get
(
"num_crops"
,
16
)
if
num_crops
is
None
:
num_crops
=
hf_config
.
get
(
"num_crops"
,
16
)
new_width
,
new_height
=
_calc_hd_transform_size
(
width
=
input_width
,
height
=
input_height
,
hd_num
=
num_crops
)
...
...
@@ -347,20 +349,26 @@ def get_phi3v_image_feature_size(
+
(
new_height
//
336
+
1
)
*
12
def
get_max_phi3v_image_tokens
(
ctx
:
InputContext
):
def
get_max_phi3v_image_tokens
(
ctx
:
InputContext
,
*
,
num_crops
:
Optional
[
int
]
=
None
):
return
get_phi3v_image_feature_size
(
ctx
.
get_hf_image_processor_config
(),
input_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
input_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
num_crops
=
num_crops
,
)
def
dummy_data_for_phi3v
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
def
dummy_data_for_phi3v
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
*
,
num_crops
:
Optional
[
int
]
=
None
):
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_phi3v_image_tokens
(
ctx
)
image_feature_size
=
get_max_phi3v_image_tokens
(
ctx
,
num_crops
=
num_crops
)
seq_data
=
dummy_seq_data_for_clip
(
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
...
...
@@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig,
return
image_placeholder_token_ids
def
input_processor_for_phi3v
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
def
input_processor_for_phi3v
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
,
*
,
num_crops
:
Optional
[
int
]
=
None
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_inputs
...
...
@@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size
=
[
get_phi3v_image_feature_size
(
hf_config
,
input_width
=
w
,
input_height
=
h
)
input_height
=
h
,
num_crops
=
num_crops
)
]
image_data
=
[
image_data
]
elif
is_list_of
(
image_data
,
Image
.
Image
):
...
...
@@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size
.
append
(
get_phi3v_image_feature_size
(
hf_config
,
input_width
=
w
,
input_height
=
h
))
input_height
=
h
,
num_crops
=
num_crops
))
elif
isinstance
(
image_data
,
torch
.
Tensor
):
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
elif
is_list_of
(
image_data
,
torch
.
Tensor
):
...
...
vllm/model_executor/models/phimoe.py
View file @
539aa992
...
...
@@ -321,13 +321,13 @@ class PhiMoEAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
True
,
quant_config
=
None
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
True
,
quant_config
=
None
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -491,6 +491,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
"o_proj"
,
"embed_tokens"
,
"lm_head"
,
"w1"
,
"w2"
,
"w3"
,
"gate"
,
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
...
...
vllm/model_executor/models/pixtral.py
View file @
539aa992
from
array
import
array
from
dataclasses
import
dataclass
,
fields
from
itertools
import
tee
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Union
...
...
@@ -24,8 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.interfaces
import
SupportsMultiModal
from
.utils
import
init_vllm_registered_model
...
...
@@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
image_feature_size
=
(
size
**
2
)
//
(
patch_size
**
2
)
num_image_tokens
=
image_feature_size
*
num_images
seq_data
=
SequenceData
.
from_token_counts
(
(
image_token_id
,
num_image_tokens
),
(
0
,
seq_len
-
num_image_tokens
),
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
num_image_tokens
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
num_image_tokens
)
seq_data
=
SequenceData
(
token_ids
)
mm_data
=
{
"image"
:
num_images
*
[
image
]}
return
seq_data
,
mm_data
...
...
@@ -454,7 +450,7 @@ class Transformer(nn.Module):
return
x
def
position_meshgrid
(
patch_embeds_list
:
l
ist
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
def
position_meshgrid
(
patch_embeds_list
:
L
ist
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
positions
=
torch
.
cat
([
torch
.
stack
(
torch
.
meshgrid
(
...
...
vllm/model_executor/models/qwen.py
View file @
539aa992
...
...
@@ -7,7 +7,6 @@
import
math
import
re
from
array
import
array
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -48,8 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.utils
import
is_list_of
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
...
...
@@ -689,8 +687,9 @@ def input_processor_for_qwen(ctx: InputContext,
prompt
=
llm_inputs
.
get
(
"prompt"
)
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
torch
.
Tensor
):
num_dims
=
len
(
image_data
.
shape
)
...
...
@@ -750,8 +749,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
return
MultiModalInputs
()
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_pair_tok
=
tokenizer
.
encode
(
IMG_START
+
IMG_END
,
add_special_tokens
=
False
,
...
...
@@ -832,15 +832,16 @@ def dummy_data_for_qwen(
# The presence of a visual config indicates this is a multimodal model.
# If we don't have it, the model is considered an LLM for warmup purposes.
if
not
hasattr
(
hf_config
,
"visual"
):
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]
*
seq_len
))
seq_data
=
SequenceData
.
from_token_counts
((
0
,
seq_len
))
mm_data
=
None
return
seq_data
,
mm_data
# We have a visual component - use images to warm up
num_images
=
mm_counts
[
"image"
]
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
# Build the image prompts with no imgpads; the tokenizer will add img pads
image_prompt
=
''
.
join
(
...
...
@@ -859,11 +860,13 @@ def dummy_data_for_qwen(
if
len
(
toks
)
<
seq_len
:
toks
+=
[
0
]
*
(
seq_len
-
len
(
toks
))
seq_data
=
SequenceData
.
from_seqs
(
toks
)
# Build the input images; width/height doesn't actually matter here since
# the data will get resized and the # of tokens per image is constant
image
=
Image
.
new
(
"RGB"
,
(
224
,
224
),
color
=
0
)
mm_data
=
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
return
S
eq
uenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
toks
))
,
mm_data
return
s
eq
_data
,
mm_data
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_qwen
)
...
...
vllm/model_executor/models/qwen2.py
View file @
539aa992
...
...
@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
...
...
@@ -247,11 +247,16 @@ class Qwen2Model(nn.Module):
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
)
if
get_pp_group
().
is_first_rank
or
(
config
.
tie_word_embeddings
and
get_pp_group
().
is_last_rank
):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen2DecoderLayer
(
config
=
config
,
...
...
@@ -260,7 +265,10 @@ class Qwen2Model(nn.Module):
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
539aa992
...
...
@@ -22,7 +22,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from
array
import
array
from
functools
import
lru_cache
,
partial
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
TypedDict
,
Union
)
...
...
@@ -46,7 +45,7 @@ from vllm.attention import AttentionMetadata
from
vllm.attention.selector
import
(
_Backend
,
backend_name_to_enum
,
get_global_forced_attn_backend
)
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
get_pp_group
,
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
...
...
@@ -66,9 +65,12 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
from
vllm.multimodal.base
import
MultiModalData
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.utils
import
is_cpu
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
)
logger
=
init_logger
(
__name__
)
...
...
@@ -207,7 +209,7 @@ class Qwen2VisionAttention(nn.Module):
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
if
selected_backend
is
None
:
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
get
_device_capability
(
)[
0
]
>=
8
device_available
=
current_platform
.
has
_device_capability
(
80
)
if
device_available
:
from
transformers.utils
import
is_flash_attn_2_available
...
...
@@ -280,6 +282,21 @@ class Qwen2VisionAttention(nn.Module):
context_layer
=
rearrange
(
output
,
"(b s) ... -> b s ..."
,
b
=
batch_size
)
elif
is_cpu
():
seq_length
=
q
.
size
(
1
)
q
,
k
,
v
=
[
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
]]
attention_mask
=
torch
.
zeros
([
1
,
seq_length
,
seq_length
],
device
=
q
.
device
,
dtype
=
torch
.
bool
)
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
attention_mask
[...,
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
],
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
]]
=
True
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attention_mask
,
dropout_p
=
0.0
)
context_layer
=
rearrange
(
output
,
"b h s d -> b s h d "
)
else
:
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
...
...
@@ -681,15 +698,14 @@ def dummy_data_for_qwen2_vl(
"--limit-mm-per-prompt."
)
hf_config
=
ctx
.
get_hf_config
(
Qwen2VLConfig
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
hf_config
.
vision_start_token_id
])
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
hf_config
.
image_token_id
])
*
max_llm_image_tokens
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
hf_config
.
vision_end_token_id
])
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
max_llm_image_tokens
-
2
)
dummy_seqdata
=
SequenceData
(
token_ids
)
dummy_seqdata
=
SequenceData
.
from_token_counts
(
(
hf_config
.
vision_start_token_id
,
1
),
(
hf_config
.
image_token_id
,
max_llm_image_tokens
),
(
hf_config
.
vision_end_token_id
,
1
),
(
0
,
seq_len
-
max_llm_image_tokens
-
2
),
)
dummy_image
=
Image
.
new
(
"RGB"
,
(
max_resized_width
,
max_resized_height
),
color
=
0
)
...
...
@@ -859,15 +875,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
model
=
Qwen2Model
(
config
,
cache_config
,
quant_config
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
Union
[
torch
.
Tensor
,
...
...
@@ -982,7 +1004,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
if
image_input
is
None
and
video_input
is
None
:
if
(
image_input
is
None
and
video_input
is
None
)
or
not
get_pp_group
().
is_first_rank
:
inputs_embeds
=
None
else
:
if
getattr
(
self
.
config
,
"rope_scaling"
,
{}).
get
(
"type"
,
...
...
@@ -1018,6 +1041,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
...
...
@@ -1058,6 +1082,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -1084,6 +1110,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
except
KeyError
:
print
(
params_dict
.
keys
())
...
...
vllm/model_executor/models/siglip.py
View file @
539aa992
...
...
@@ -2,9 +2,9 @@
within a vision language model."""
import
math
from
array
import
array
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
from
PIL
import
Image
from
torch
import
nn
...
...
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.sequence
import
SequenceData
try
:
from
xformers
import
ops
as
xops
...
...
@@ -67,11 +67,10 @@ def dummy_seq_data_for_siglip(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
return
SequenceData
.
from_token_counts
(
(
image_token_id
,
image_feature_size
*
num_images
),
(
0
,
seq_len
-
image_feature_size
*
num_images
),
)
def
dummy_image_for_siglip
(
...
...
@@ -91,6 +90,24 @@ def dummy_image_for_siglip(
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_video_for_siglip
(
hf_config
:
SiglipVisionConfig
,
num_frames
:
int
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
pil_frame
=
dummy_image_for_siglip
(
hf_config
,
num_images
=
1
,
image_width_override
=
image_width_override
,
image_height_override
=
image_height_override
)
np_frame
=
np
.
array
(
pil_frame
[
"image"
])
mm_data_per_video
=
np
.
repeat
([
np_frame
],
num_frames
,
axis
=
0
)
mm_data
=
{
"video"
:
mm_data_per_video
}
return
mm_data
def
input_processor_for_siglip
(
model_config
:
ModelConfig
,
hf_config
:
SiglipVisionConfig
,
...
...
@@ -503,6 +520,7 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
):
super
().
__init__
()
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
...
...
@@ -513,10 +531,6 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override
=
num_hidden_layers_override
,
)
@
property
def
_require_post_layernorm
(
self
)
->
bool
:
return
self
.
vision_model
.
post_layernorm
is
not
None
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
vision_model
.
embeddings
.
patch_embedding
...
...
@@ -542,12 +556,12 @@ class SiglipVisionModel(nn.Module):
for
name
,
loaded_weight
in
weights
:
# post_layernorm is optional in SiglipVisionModel
if
(
"vision_model.post_layernorm"
in
name
and
not
self
.
_require_
post_layernorm
):
if
(
name
.
startswith
(
"vision_model.post_layernorm"
)
and
self
.
vision_model
.
post_layernorm
is
None
):
continue
# omit layers when num_hidden_layers_override is set
if
"vision_model.encoder.layers
."
in
name
:
if
name
.
startswith
(
"vision_model.encoder.layers
"
)
:
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
if
layer_idx
>=
layer_count
:
continue
...
...
vllm/model_executor/models/solar.py
0 → 100644
View file @
539aa992
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
get_compressed_tensors_cache_scale
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.interfaces
import
SupportsLoRA
from
vllm.model_executor.models.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
class
SolarMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
intermediate_size
]
*
2
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
hidden_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
SolarAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
self
.
hidden_size
//
self
.
total_num_heads
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
=
hidden_size
,
head_size
=
self
.
head_dim
,
total_num_heads
=
self
.
total_num_heads
,
total_num_kv_heads
=
self
.
total_num_kv_heads
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
output_size
=
hidden_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
SolarDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
if
rope_scaling
is
not
None
and
getattr
(
config
,
"original_max_position_embeddings"
,
None
):
rope_scaling
[
"original_max_position_embeddings"
]
\
=
config
.
original_max_position_embeddings
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias
=
getattr
(
config
,
"attention_bias"
,
False
)
or
getattr
(
config
,
"bias"
,
False
)
self
.
self_attn
=
SolarAttention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
getattr
(
config
,
"num_key_value_heads"
,
config
.
num_attention_heads
),
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
quant_config
=
quant_config
,
bias
=
attention_bias
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
SolarMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
class
SolarModel
(
nn
.
Module
):
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
((
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
)
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
org_vocab_size
=
config
.
vocab_size
if
get_pp_group
().
is_first_rank
or
(
config
.
tie_word_embeddings
and
get_pp_group
().
is_last_rank
):
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
SolarDecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
bskcn_h_1
=
None
bskcn_h_2
=
None
bskcn_r_1
=
None
bskcn_r_2
=
None
bskcn_tv
=
(
self
.
config
.
bskcn_tv
[
0
]
if
self
.
training
else
self
.
config
.
bskcn_tv
[
1
])
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
if
i
in
self
.
config
.
bskcn_1
:
bskcn_h_1
=
hidden_states
.
clone
()
bskcn_r_1
=
residual
.
clone
()
if
i
in
self
.
config
.
bskcn_2
:
bskcn_h_2
=
hidden_states
.
clone
()
bskcn_r_2
=
residual
.
clone
()
if
i
in
self
.
config
.
bskcn_3
:
hidden_states
=
bskcn_h_1
*
bskcn_tv
+
hidden_states
*
(
1
-
bskcn_tv
)
residual
=
bskcn_r_1
*
bskcn_tv
+
residual
*
(
1
-
bskcn_tv
)
if
i
in
self
.
config
.
bskcn_4
:
hidden_states
=
bskcn_h_2
*
bskcn_tv
+
hidden_states
*
(
1
-
bskcn_tv
)
residual
=
bskcn_r_2
*
bskcn_tv
+
residual
*
(
1
-
bskcn_tv
)
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
SolarForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"embed_tokens"
,
"lm_head"
,
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
SolarModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
,
)
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
else
:
self
.
lm_head
=
PPMissingLayer
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
(
(
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
,
),
"residual"
:
torch
.
zeros
(
(
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
,
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if
scale_name
:
=
get_compressed_tensors_cache_scale
(
name
):
# Loading kv cache scales for compressed-tensors quantization
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
loaded_weight
[
0
]
weight_loader
(
param
,
loaded_weight
)
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def
load_kv_cache_scales
(
self
,
quantization_param_path
:
str
)
->
None
:
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
for
layer_idx
,
scaling_factor
in
kv_cache_scales_loader
(
quantization_param_path
,
tp_rank
,
tp_size
,
self
.
config
.
num_hidden_layers
,
self
.
config
.
__class__
.
model_type
,
):
if
not
isinstance
(
self
.
model
.
layers
[
layer_idx
],
nn
.
Identity
):
layer_self_attn
=
self
.
model
.
layers
[
layer_idx
].
self_attn
if
is_hip
():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor
*=
2
if
hasattr
(
layer_self_attn
,
"kv_scale"
):
layer_self_attn
.
attn
.
_kv_scale
=
scaling_factor
else
:
raise
RuntimeError
(
"Self attention has no KV cache scaling "
"factor attribute!"
)
vllm/model_executor/models/ultravox.py
View file @
539aa992
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import
itertools
import
math
from
array
import
array
from
functools
import
lru_cache
...
...
@@ -21,15 +20,16 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs.data
import
LLMInputs
from
vllm.inputs.registry
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.utils
import
(
filter_weights
,
flatten_bn
,
from
vllm.model_executor.models.utils
import
(
flatten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -43,8 +43,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN
=
128002
_AUDIO_TOKENS_PER_SECOND
=
6.25
logger
=
init_logger
(
__name__
)
class
UltravoxAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
...
...
@@ -77,15 +75,11 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
return
math
.
ceil
(
feature_extractor
.
chunk_length
*
_AUDIO_TOKENS_PER_SECOND
)
def
dummy_data_for_ultravox
(
def
dummy_
seq_
data_for_ultravox
(
ctx
:
InputContext
,
seq_len
:
int
,
mm
_count
s
:
Mapping
[
str
,
int
]
,
audio
_count
:
int
,
):
feature_extractor
=
whisper_feature_extractor
(
ctx
)
audio_count
=
mm_counts
[
"audio"
]
audio_placeholder
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
_AUDIO_PLACEHOLDER_TOKEN
])
*
get_ultravox_max_audio_tokens
(
ctx
)
...
...
@@ -96,10 +90,28 @@ def dummy_data_for_ultravox(
other_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
len
(
audio_token_ids
))
return
SequenceData
(
audio_token_ids
+
other_token_ids
)
def
dummy_audio_for_ultravox
(
ctx
:
InputContext
,
audio_count
:
int
,
):
feature_extractor
=
whisper_feature_extractor
(
ctx
)
audio_and_sr
=
(
np
.
array
([
0.0
]
*
feature_extractor
.
chunk_length
),
1
)
mm_dict
=
{
"audio"
:
[
audio_and_sr
]
*
audio_count
}
return
{
"audio"
:
[
audio_and_sr
]
*
audio_count
}
def
dummy_data_for_ultravox
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
):
audio_count
=
mm_counts
[
"audio"
]
seq_data
=
dummy_seq_data_for_ultravox
(
ctx
,
seq_len
,
audio_count
)
mm_dict
=
dummy_audio_for_ultravox
(
ctx
,
audio_count
)
return
(
S
eq
uenceData
(
audio_token_ids
+
other_token_ids
)
,
mm_dict
)
return
(
s
eq
_data
,
mm_dict
)
def
input_mapper_for_ultravox
(
ctx
:
InputContext
,
data
:
object
):
...
...
@@ -323,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
self
.
multi_modal_config
=
multimodal_config
assert
self
.
multi_modal_config
self
.
secondary_weights
=
[]
self
.
audio_tower
=
ModifiedWhisperEncoder
(
config
.
audio_config
)
if
config
.
audio_model_id
is
not
None
:
self
.
audio_tower
=
ModifiedWhisperEncoder
.
from_pretrained
(
config
.
audio_model_id
)
else
:
self
.
audio_tower
=
ModifiedWhisperEncoder
(
config
.
audio_config
)
self
.
secondary_weights
.
append
(
DefaultModelLoader
.
Source
(
model_or_path
=
config
.
audio_model_id
,
revision
=
None
,
prefix
=
"audio_tower."
,
))
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
if
config
.
text_model_id
is
not
None
:
self
.
secondary_weights
.
append
(
DefaultModelLoader
.
Source
(
model_or_path
=
config
.
text_model_id
,
revision
=
None
,
prefix
=
"language_model."
))
def
_audio_features_to_embeddings
(
self
,
input_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -453,11 +474,22 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
projector_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
2
)
weights_group
=
group_weights_with_prefix
(
weights
)
# load audio tower weights
audio_tower_weights
=
weights_group
[
"audio_tower"
]
audio_tower_params_dict
=
dict
(
self
.
audio_tower
.
named_parameters
(
prefix
=
self
.
audio_tower
.
base_model_prefix
))
for
name
,
loaded_weight
in
audio_tower_weights
:
if
name
in
audio_tower_params_dict
:
param
=
audio_tower_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load projector weights
projector_weights
=
filter_weights
(
projector_weights
,
"multi_modal_projector"
)
projector_weights
=
weights_group
[
"multi_modal_projector"
]
projector_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
projector_weights
:
...
...
@@ -467,5 +499,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/utils.py
View file @
539aa992
import
itertools
from
collections
import
UserDict
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
...
...
@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
def
filter_weights
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
):
class
WeightsGroup
(
UserDict
):
"""
Wraps grouped weights dictionary for a more informative error message
when attempting to access a weight component that does not exist.
"""
def
__getitem__
(
self
,
key
:
str
)
->
int
:
try
:
return
super
().
__getitem__
(
key
)
except
KeyError
as
exc
:
msg
=
(
f
"There is no weights named with the prefix:
{
key
}
. "
f
"Available prefix:
{
set
(
self
.
keys
())
}
"
)
raise
KeyError
(
msg
)
from
exc
def
filter_weights
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
)
->
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]:
"""
Helper function to load weights for inner vLLM models.
...
...
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
yield
name
,
loaded_weight
def
group_weights_with_prefix
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]
)
->
Dict
[
str
,
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]]:
"""
Helper function to group weights with prefix
"""
init_weights
,
repeated_weights
=
itertools
.
tee
(
weights
,
2
)
weights_prefix
=
{
name
.
split
(
"."
)[
0
]
for
name
,
_
in
init_weights
}
repeated_weights
=
itertools
.
tee
(
repeated_weights
,
len
(
weights_prefix
))
return
WeightsGroup
({
prefix
:
filter_weights
(
component
,
prefix
)
for
component
,
prefix
in
zip
(
repeated_weights
,
weights_prefix
)
})
def
init_vllm_registered_model
(
hf_config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
],
...
...
vllm/model_executor/parameter.py
View file @
539aa992
...
...
@@ -328,6 +328,64 @@ class PackedvLLMParameter(ModelWeightParameter):
marlin_tile_size
=
self
.
marlin_tile_size
)
def
permute_param_layout_
(
param
:
BasevLLMParameter
,
input_dim
:
int
,
output_dim
:
int
,
**
kwargs
)
->
BasevLLMParameter
:
"""
Permute a parameter's layout to the specified input and output dimensions,
useful for forcing the parameter into a known layout, for example, if I need
a packed (quantized) weight matrix to be in the layout
{input_dim = 0, output_dim = 1, packed_dim = 0}
then I can call:
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
to ensure x is in the correct layout (permuting it to the correct layout if
required, asserting if it cannot get it to the correct layout)
"""
curr_input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
curr_output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
if
curr_input_dim
is
None
or
curr_output_dim
is
None
:
assert
param
.
data
.
dim
()
==
2
,
\
"permute_param_layout_ only supports 2D parameters when either "
\
"input_dim or output_dim is not set"
# if one of the dimensions is not set, set it to the opposite of the other
# we can only do this since we asserted the parameter is 2D above
if
curr_input_dim
is
None
:
assert
curr_output_dim
is
not
None
,
\
"either input or output dim must be set"
curr_input_dim
=
(
curr_output_dim
+
1
)
%
2
if
curr_output_dim
is
None
:
assert
curr_input_dim
is
not
None
,
\
"either input or output dim must be set"
curr_output_dim
=
(
curr_input_dim
+
1
)
%
2
# create permutation from the current layout to the layout with
# self.input_dim at input_dim and self.output_dim at output_dim preserving
# other dimensions
perm
=
[
i
for
i
in
range
(
param
.
data
.
dim
())
if
i
not
in
[
curr_input_dim
,
curr_output_dim
]
]
perm
.
insert
(
input_dim
,
curr_input_dim
)
perm
.
insert
(
output_dim
,
curr_output_dim
)
if
"packed_dim"
in
kwargs
:
assert
hasattr
(
param
,
"packed_dim"
)
and
\
param
.
packed_dim
==
perm
[
kwargs
[
"packed_dim"
]],
\
"permute_param_layout_ currently doesn't support repacking"
param
.
data
=
param
.
data
.
permute
(
*
perm
)
if
hasattr
(
param
,
"_input_dim"
):
param
.
_input_dim
=
input_dim
if
hasattr
(
param
,
"_output_dim"
):
param
.
_output_dim
=
output_dim
if
"packed_dim"
in
kwargs
and
hasattr
(
param
,
"_packed_dim"
):
param
.
_packed_dim
=
kwargs
[
"packed_dim"
]
return
param
def
_adjust_shard_indexes_for_marlin
(
shard_size
,
shard_offset
,
marlin_tile_size
):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
...
...
vllm/model_executor/sampling_metadata.py
View file @
539aa992
import
random
from
array
import
array
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -8,15 +7,10 @@ import torch
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
is_pin_memory_available
,
make_tensor_with_pad
)
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER
=
False
@
dataclass
...
...
@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
generator
=
None
,
is_prompt
=
True
,
prompt_logprob_indices
=
[],
sample_indices
=
[])
sample_indices
=
[],
)
class
SamplingMetadataCache
:
"""Used to cache SamplingMetadata objects between scheduler iterations
"""
"""Used to cache SamplingMetadata objects between scheduler iterations"""
def
__init__
(
self
):
self
.
_seq_group_to_sample_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
...
...
@@ -124,12 +118,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
"""
def
__init__
(
...
...
@@ -165,16 +159,19 @@ class SamplingMetadata:
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
,
generators
,
cache
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
device
,
pin_memory
=
pin_memory
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
device
,
pin_memory
=
pin_memory
,
)
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
async_tensor_h2d
(
seq_ids
,
dtype
=
torch
.
int
,
target_device
=
device
,
pin_memory
=
pin_memory
),
2
,
2
)
t
:
async_tensor_h2d
(
seq_ids
,
dtype
=
torch
.
int
,
target_device
=
device
,
pin_memory
=
pin_memory
,
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
...
...
@@ -201,8 +198,8 @@ def _prepare_seq_groups(
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
int
]]
,
int
,
]:
"""Prepare sequence groups and indices for sampling.
Args:
...
...
@@ -233,16 +230,13 @@ def _prepare_seq_groups(
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]
]]
=
{
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
logit_idx
=
0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx
=
0
# Total number of prompts from given sequence groups.
num_prompts
=
0
...
...
@@ -264,10 +258,10 @@ def _prepare_seq_groups(
# If the current seq group is in decode stage, it is None.
seq_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
prompt_logprob_indices
:
List
[
int
]
=
\
sample_obj
.
prompt_logprob_indices
if
cache
is
not
None
else
[]
sample_indices
:
List
[
int
]
=
\
sample_obj
.
sample_indices
if
cache
is
not
None
else
[]
prompt_logprob_indices
:
List
[
int
]
=
(
sample_obj
.
prompt_logprob_indices
if
cache
is
not
None
else
[]
)
sample_indices
:
List
[
int
]
=
(
sample_obj
.
sample_indices
if
cache
is
not
None
else
[]
)
do_sample
=
seq_group_metadata
.
do_sample
if
seq_group_metadata
.
is_prompt
:
...
...
@@ -333,11 +327,8 @@ def _prepare_seq_groups(
if
do_sample
:
sample_indices
.
extend
(
range
(
logit_idx
,
logit_idx
+
sample_len
))
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
list
(
zip
(
range
(
logit_idx
,
logit_idx
+
sample_len
),
range
(
sample_idx
,
sample_idx
+
sample_len
))))
list
(
range
(
logit_idx
,
logit_idx
+
sample_len
)))
logit_idx
+=
sample_len
sample_idx
+=
sample_len
if
cache
is
not
None
:
sample_obj
.
sampling_params
=
sampling_params
...
...
@@ -356,7 +347,8 @@ def _prepare_seq_groups(
generator
=
generator
,
is_prompt
=
is_prompt
,
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
sample_indices
=
list
(
sample_indices
))
sample_indices
=
list
(
sample_indices
),
)
seq_groups
.
append
(
sample_obj
)
...
...
@@ -378,9 +370,6 @@ class SamplingTensors:
presence_penalties
:
torch
.
Tensor
frequency_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
sampling_seeds
:
torch
.
Tensor
sample_indices
:
torch
.
Tensor
extra_seeds
:
Optional
[
torch
.
Tensor
]
prompt_tokens
:
torch
.
Tensor
output_tokens
:
torch
.
Tensor
...
...
@@ -391,15 +380,7 @@ class SamplingTensors:
vocab_size
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
*
,
extra_seeds_to_generate
:
int
=
0
,
extra_entropy
:
Optional
[
Tuple
[
int
,
...]]
=
None
)
->
Tuple
[
"SamplingTensors"
,
bool
,
bool
,
bool
]:
"""
extra_seeds_to_generate: extra seeds to generate using the
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens
:
List
[
array
]
=
[]
output_tokens
:
List
[
array
]
=
[]
top_ks
:
List
[
int
]
=
[]
...
...
@@ -409,19 +390,10 @@ class SamplingTensors:
presence_penalties
:
List
[
float
]
=
[]
frequency_penalties
:
List
[
float
]
=
[]
repetition_penalties
:
List
[
float
]
=
[]
sampling_seeds
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
do_penalties
=
False
do_top_p_top_k
=
False
do_min_p
=
False
if
_USE_TRITON_SAMPLER
:
prompt_best_of
:
List
[
int
]
=
[]
# We need one base seed per Triton slice.
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
assert
sampling_metadata
.
seq_groups
is
not
None
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
...
...
@@ -452,7 +424,7 @@ class SamplingTensors:
do_penalties
=
True
is_prompt
=
seq_group
.
is_prompt
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
)
:
if
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
:
# For tokens in the prompt that we only need to get
# their logprobs
query_len
=
seq_group
.
query_len
...
...
@@ -477,28 +449,6 @@ class SamplingTensors:
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
if
_USE_TRITON_SAMPLER
:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
seed
=
sampling_params
.
seed
is_greedy
=
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
extra_entropy
=
extra_entropy
or
()
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
seq_data
.
get_len
(),
*
extra_entropy
,
seq_id
,
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
extend
(
seq_group
.
sample_indices
)
if
do_penalties
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
...
...
@@ -518,23 +468,37 @@ class SamplingTensors:
output_tokens
.
append
(
seq_data
.
output_token_ids_array
)
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
,
sampling_seeds
,
sample_indices
,
prompt_tokens
,
output_tokens
,
vocab_size
,
extra_seeds_to_generate
,
device
,
dtype
)
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
,
prompt_tokens
,
output_tokens
,
vocab_size
,
device
,
dtype
,
)
return
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
@
classmethod
def
from_lists
(
cls
,
temperatures
:
List
[
float
],
top_ps
:
List
[
float
],
top_ks
:
List
[
int
],
min_ps
:
List
[
float
],
presence_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
sampling_seeds
:
List
[
int
],
sample_indices
:
List
[
int
],
prompt_tokens
:
List
[
array
],
output_tokens
:
List
[
array
],
vocab_size
:
int
,
extra_seeds_to_generate
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
"SamplingTensors"
:
def
from_lists
(
cls
,
temperatures
:
List
[
float
],
top_ps
:
List
[
float
],
top_ks
:
List
[
int
],
min_ps
:
List
[
float
],
presence_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
prompt_tokens
:
List
[
array
],
output_tokens
:
List
[
array
],
vocab_size
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
)
->
"SamplingTensors"
:
# Note that the performance will be very bad without
# pinned memory.
pin_memory
=
is_pin_memory_available
()
...
...
@@ -603,34 +567,9 @@ class SamplingTensors:
dtype
=
torch
.
int
,
pin_memory
=
pin_memory
,
)
sample_indices_t
=
torch
.
tensor
(
sample_indices
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
sampling_seeds_t
=
torch
.
tensor
(
sampling_seeds
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
).
t
().
contiguous
()
# Because the memory is pinned, we can do non-blocking
# transfer to device.
# How many seeds the sample operation itself will need.
num_base_seeds
=
sampling_seeds_t
.
shape
[
0
]
-
extra_seeds_to_generate
sampling_seeds_gpu
=
sampling_seeds_t
.
to
(
device
=
device
,
non_blocking
=
True
)
extra_seeds_gpu
=
sampling_seeds_gpu
[
num_base_seeds
:]
if
not
extra_seeds_gpu
.
numel
():
extra_seeds_gpu
=
None
sampling_seeds_gpu
=
sampling_seeds_gpu
[:
num_base_seeds
]
return
cls
(
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
...
...
@@ -644,38 +583,4 @@ class SamplingTensors:
non_blocking
=
True
),
prompt_tokens
=
prompt_t
.
to
(
device
=
device
,
non_blocking
=
True
),
output_tokens
=
output_t
.
to
(
device
=
device
,
non_blocking
=
True
),
sampling_seeds
=
sampling_seeds_gpu
,
sample_indices
=
sample_indices_t
.
to
(
device
=
device
,
non_blocking
=
True
),
extra_seeds
=
extra_seeds_gpu
,
)
@
staticmethod
def
_get_sequence_seeds
(
seed
:
int
,
*
extra_entropy
:
int
,
seeds_to_generate
:
int
,
is_greedy
:
bool
,
):
"""Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
if
not
is_greedy
:
if
seed
is
None
:
randint_fn
=
random
.
randint
else
:
generator
=
random
.
Random
(
str
((
seed
,
)
+
extra_entropy
))
randint_fn
=
generator
.
randint
lo
,
hi
=
torch
.
iinfo
(
torch
.
long
).
min
,
torch
.
iinfo
(
torch
.
long
).
max
# If the user/random sets seed = 0 but request should
# have sampling, we need to change it to something
# else. We use a constant in that case.
# This way we don't need to create and load a bool
# matrix in the sampling kernel, which reduces CPU
# overhead and latency.
seq_seeds
=
[
randint_fn
(
lo
,
hi
)
or
_SEED_0_REPLACEMENT
for
_
in
range
(
seeds_to_generate
)
]
else
:
# For the kernel, seed == 0 means greedy decoding.
seq_seeds
=
[
0
]
*
seeds_to_generate
return
seq_seeds
vllm/model_executor/utils.py
View file @
539aa992
"""Utils for model executor."""
import
random
from
typing
import
Any
,
Dict
,
Optional
import
numpy
as
np
import
torch
from
vllm.utils
import
seed_everything
def
set_random_seed
(
seed
:
int
)
->
None
:
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
seed
)
seed_everything
(
seed
)
def
set_weight_attrs
(
...
...
vllm/multimodal/base.py
View file @
539aa992
...
...
@@ -14,7 +14,8 @@ from typing_extensions import TypeAlias
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.utils
import
JSONTree
,
is_list_of
,
json_map_leaves
from
vllm.utils
import
(
JSONTree
,
get_allowed_kwarg_only_overrides
,
is_list_of
,
json_map_leaves
)
logger
=
init_logger
(
__name__
)
...
...
@@ -53,6 +54,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if
isinstance
(
nested_tensors
,
torch
.
Tensor
):
return
nested_tensors
if
isinstance
(
nested_tensors
,
np
.
ndarray
):
return
torch
.
from_numpy
(
nested_tensors
)
if
isinstance
(
nested_tensors
,
(
int
,
float
)):
return
torch
.
tensor
(
nested_tensors
)
stacked
=
[
MultiModalInputs
.
_try_stack
(
t
)
for
t
in
nested_tensors
]
if
not
is_list_of
(
stacked
,
torch
.
Tensor
,
check
=
"all"
):
# Only tensors (not lists) can be stacked.
...
...
@@ -256,11 +263,20 @@ class MultiModalPlugin(ABC):
model_cls
,
_
=
get_model_architecture
(
model_config
)
mapper
=
self
.
_input_mappers
.
get
(
model_cls
)
# Only get processor kwargs at mapping time if we are not using the
# input mapper; no overrides are used on the default here because they
# should be passed to the huggingface resource at initialization time.
if
mapper
is
not
None
and
mapper
!=
self
.
_default_input_mapper
:
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
mapper
,
overrides
=
model_config
.
mm_processor_kwargs
)
else
:
mm_processor_kwargs
=
{}
if
mapper
is
None
:
raise
KeyError
(
f
"No input mapper in
{
self
}
is registered for "
f
"model class
{
model_cls
.
__name__
}
."
)
return
mapper
(
InputContext
(
model_config
),
data
)
return
mapper
(
InputContext
(
model_config
),
data
,
**
mm_processor_kwargs
)
@
abstractmethod
def
_default_max_multimodal_tokens
(
self
,
ctx
:
InputContext
)
->
int
:
...
...
@@ -333,7 +349,10 @@ class MultiModalPlugin(ABC):
f
"for model class
{
model_cls
.
__name__
}
in
{
self
}
."
)
if
callable
(
max_mm_tokens
):
max_mm_tokens
=
max_mm_tokens
(
InputContext
(
model_config
))
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
max_mm_tokens
,
overrides
=
model_config
.
mm_processor_kwargs
)
max_mm_tokens
=
max_mm_tokens
(
InputContext
(
model_config
),
**
mm_processor_kwargs
)
self
.
_validate_max_multimodal_tokens
(
max_mm_tokens
)
...
...
Prev
1
…
13
14
15
16
17
18
19
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment