Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2640 additions
and
313 deletions
+2640
-313
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+111
-38
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+1142
-0
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+409
-0
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+7
-10
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+6
-6
vllm/model_executor/models/phi3.py
vllm/model_executor/models/phi3.py
+17
-0
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+22
-9
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+6
-2
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+6
-10
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+14
-11
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+15
-7
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+48
-20
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+28
-14
vllm/model_executor/models/solar.py
vllm/model_executor/models/solar.py
+580
-0
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+53
-22
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+35
-1
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+58
-0
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+58
-153
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+3
-7
vllm/multimodal/base.py
vllm/multimodal/base.py
+22
-3
No files found.
vllm/model_executor/models/minicpmv.py
View file @
539aa992
...
@@ -23,7 +23,6 @@
...
@@ -23,7 +23,6 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import
math
import
math
import
re
import
re
from
array
import
array
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
)
TypedDict
)
...
@@ -34,11 +33,11 @@ from PIL import Image
...
@@ -34,11 +33,11 @@ from PIL import Image
from
torch
import
nn
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
torch.nn.init
import
trunc_normal_
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
@@ -54,21 +53,30 @@ from vllm.model_executor.models.minicpm import MiniCPMModel
...
@@ -54,21 +53,30 @@ from vllm.model_executor.models.minicpm import MiniCPMModel
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
SequenceData
)
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.idefics2_vision_model
import
Idefics2VisionTransformer
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
_KEYS_TO_MODIFY_MAPPING
=
{
"llm.lm_head"
:
"lm_head"
,
"llm.lm_head"
:
"lm_head"
,
"llm.model"
:
"llm"
,
"llm.model"
:
"llm"
,
}
}
class MiniCPMVImageInput(TypedDict):
    """Input mapper input with auxiliary data for computing image bounds."""
    # The PIL image to be preprocessed.
    image: Image.Image

    # Image bound token ids, each stored as a 0-dim scalar tensor.
    im_start_id: torch.Tensor
    im_end_id: torch.Tensor
    # Slice bound token ids; optional keys that may be absent.
    slice_start_id: NotRequired[torch.Tensor]
    slice_end_id: NotRequired[torch.Tensor]
class
MiniCPMVImagePixelInputs
(
TypedDict
):
class
MiniCPMVImagePixelInputs
(
TypedDict
):
pixel_values
:
List
[
torch
.
Tensor
]
pixel_values
:
List
[
torch
.
Tensor
]
"""
"""
...
@@ -93,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
...
@@ -93,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
"""
"""
MiniCPMVImageInputs
=
MiniCPMVImagePixelInputs
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
...
@@ -239,6 +245,25 @@ class Resampler2_5(BaseResampler):
...
@@ -239,6 +245,25 @@ class Resampler2_5(BaseResampler):
return
x
return
x
def _build_image_input(ctx: InputContext,
                       image: Image.Image) -> MiniCPMVImageInput:
    """Bundle an image with the tokenizer's image-bound token ids.

    The tokenizer is looked up through ``cached_get_tokenizer`` using the
    model config held by ``ctx``. ``im_start_id`` and ``im_end_id`` are always
    included; ``slice_start_id`` and ``slice_end_id`` are added only when the
    tokenizer exposes a ``slice_start_id`` attribute.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    image_input = MiniCPMVImageInput(
        image=image,
        im_start_id=torch.tensor(tokenizer.im_start_id),
        im_end_id=torch.tensor(tokenizer.im_end_id))

    # Slice bound ids are optional and depend on the tokenizer.
    if hasattr(tokenizer, "slice_start_id"):
        image_input["slice_start_id"] = torch.tensor(tokenizer.slice_start_id)
        image_input["slice_end_id"] = torch.tensor(tokenizer.slice_end_id)

    return image_input
def
get_version_by_config
(
config
:
PretrainedConfig
)
->
Tuple
[
int
,
...]:
def
get_version_by_config
(
config
:
PretrainedConfig
)
->
Tuple
[
int
,
...]:
version_float
=
getattr
(
config
,
"version"
,
None
)
version_float
=
getattr
(
config
,
"version"
,
None
)
...
@@ -259,14 +284,16 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
...
@@ -259,14 +284,16 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
seq_len
return
SequenceData
.
from_token_counts
((
0
,
seq_len
))
return
SequenceData
(
token_ids
)
def
dummy_image_for_minicpmv
(
hf_config
:
PretrainedConfig
,
num_images
:
int
):
def
dummy_image_for_minicpmv
(
ctx
:
InputContext
,
hf_config
:
PretrainedConfig
,
num_images
:
int
):
width
=
height
=
hf_config
.
image_size
width
=
height
=
hf_config
.
image_size
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
image
=
_build_image_input
(
ctx
,
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
))
return
{
"image"
:
[
image
]
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
,
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
,
...
@@ -275,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
...
@@ -275,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
num_images
=
mm_counts
[
"image"
]
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
,
num_images
)
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
,
num_images
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
,
num_images
)
mm_data
=
dummy_image_for_minicpmv
(
ctx
,
hf_config
,
num_images
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
@@ -286,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -286,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
return
llm_inputs
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
version
=
get_version_by_config
(
model_config
.
hf_config
)
version
=
get_version_by_config
(
model_config
.
hf_config
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer
=
cached_get_tokenizer
(
trust_remote_code
=
True
)
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_processor
=
cached_get_image_processor
(
model_config
.
tokenizer
)
image_processor
=
cached_get_image_processor
(
model_config
.
tokenizer
)
def
get_placeholder
(
image_size
:
Tuple
[
int
,
int
],
num_image
:
int
):
def
get_placeholder
(
image_size
:
Tuple
[
int
,
int
],
num_image
:
int
):
...
@@ -323,6 +351,10 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -323,6 +351,10 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt
=
""
.
join
(
new_prompt_chunks
)
new_prompt
=
""
.
join
(
new_prompt_chunks
)
new_token_ids
=
tokenizer
.
encode
(
new_prompt
)
new_token_ids
=
tokenizer
.
encode
(
new_prompt
)
multi_modal_data
[
"image"
]
=
[
_build_image_input
(
ctx
,
image
)
for
image
in
images
]
llm_inputs
=
LLMInputs
(
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
prompt
=
new_prompt
,
...
@@ -331,6 +363,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -331,6 +363,32 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
return
llm_inputs
def input_mapper_for_minicpmv(ctx: InputContext, data: object):
    """Map a list of ``MiniCPMVImageInput`` items to ``MultiModalInputs``.

    The images are preprocessed in one batch by the cached HuggingFace image
    processor, and the image-bound token ids of the first item are attached
    to the result.

    Args:
        ctx: Input context providing ``model_config``.
        data: Expected to be a list of ``MiniCPMVImageInput`` dictionaries.

    Returns:
        ``MultiModalInputs`` wrapping the preprocessed batch plus
        ``im_start_id`` / ``im_end_id`` (and ``slice_start_id`` /
        ``slice_end_id`` when present on the first item).

    Raises:
        RuntimeError: If no image processor is available.
        ValueError: If ``data`` is not a list.
    """
    model_config = ctx.model_config

    image_processor = cached_get_image_processor(
        model_config.model, trust_remote_code=model_config.trust_remote_code)
    if image_processor is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")

    if not isinstance(data, list):
        # Exceptions do not apply %-style arguments the way logging calls do,
        # so the message has to be formatted before it is raised.
        raise ValueError(
            f"Image input must be list of MiniCPMVImageInput, got ({data})")

    batch_data = image_processor \
        .preprocess([img["image"] for img in data], return_tensors="pt") \
        .data

    if data:
        # Only the first item's bound token ids are copied onto the batch.
        first = data[0]
        batch_data["im_start_id"] = first["im_start_id"]
        batch_data["im_end_id"] = first["im_end_id"]
        if "slice_start_id" in first:
            batch_data["slice_start_id"] = first["slice_start_id"]
            batch_data["slice_end_id"] = first["slice_end_id"]

    return MultiModalInputs(batch_data)
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsMultiModal
):
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsMultiModal
):
"""
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
The abstract class of MiniCPMV can only be inherited, but cannot be
...
@@ -371,7 +429,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
...
@@ -371,7 +429,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
def
get_embedding
(
def
get_embedding
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
image_inputs
:
Optional
[
MiniCPMVImageInputs
],
image_inputs
:
Optional
[
MiniCPMVImage
Pixel
Inputs
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
vlm_embedding
:
torch
.
Tensor
=
self
.
llm
.
embed_tokens
(
input_ids
)
vlm_embedding
:
torch
.
Tensor
=
self
.
llm
.
embed_tokens
(
input_ids
)
if
hasattr
(
self
.
config
,
"scale_emb"
):
if
hasattr
(
self
.
config
,
"scale_emb"
):
...
@@ -399,14 +457,20 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
...
@@ -399,14 +457,20 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
return
vlm_embedding
,
vision_hidden_states
return
vlm_embedding
,
vision_hidden_states
def
_get_image_bounds
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_get_image_bounds
(
tokenizer
=
cached_get_tokenizer
(
self
.
config
.
_name_or_path
,
self
,
trust_remote_code
=
True
)
input_ids
:
torch
.
Tensor
,
start_cond
=
input_ids
==
tokenizer
.
im_start_id
im_start_id
:
torch
.
Tensor
,
end_cond
=
input_ids
==
tokenizer
.
im_end_id
im_end_id
:
torch
.
Tensor
,
if
hasattr
(
tokenizer
,
"slice_start_id"
):
slice_start_id
:
Optional
[
torch
.
Tensor
]
=
None
,
start_cond
|=
(
input_ids
==
tokenizer
.
slice_start_id
)
slice_end_id
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
end_cond
|=
(
input_ids
==
tokenizer
.
slice_end_id
)
# All the images in the batch should share the same special image
# bound token ids.
start_cond
=
input_ids
==
im_start_id
[
0
]
end_cond
=
input_ids
==
im_end_id
[
0
]
if
slice_start_id
is
not
None
:
start_cond
|=
(
input_ids
==
slice_start_id
[
0
])
end_cond
|=
(
input_ids
==
slice_end_id
[
0
])
image_start_tokens
,
=
torch
.
where
(
start_cond
)
image_start_tokens
,
=
torch
.
where
(
start_cond
)
image_start_tokens
+=
1
image_start_tokens
+=
1
...
@@ -425,7 +489,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
...
@@ -425,7 +489,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
**
kwargs
:
object
,
**
kwargs
:
object
,
)
->
Optional
[
MiniCPMVImageInputs
]:
)
->
Optional
[
MiniCPMVImage
Pixel
Inputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
[])
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
[])
tgt_sizes
=
kwargs
.
pop
(
"tgt_sizes"
,
[])
tgt_sizes
=
kwargs
.
pop
(
"tgt_sizes"
,
[])
...
@@ -462,8 +526,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
...
@@ -462,8 +526,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
if
len
(
pixel_values_flat
)
==
0
:
if
len
(
pixel_values_flat
)
==
0
:
return
None
return
None
return
MiniCPMVImageInputs
(
im_start_id
=
kwargs
.
pop
(
"im_start_id"
,
None
)
image_bounds
=
self
.
_get_image_bounds
(
input_ids
),
im_end_id
=
kwargs
.
pop
(
"im_end_id"
,
None
)
slice_start_id
=
kwargs
.
pop
(
"slice_start_id"
,
None
)
slice_end_id
=
kwargs
.
pop
(
"slice_end_id"
,
None
)
if
im_start_id
is
None
:
return
None
return
MiniCPMVImagePixelInputs
(
image_bounds
=
self
.
_get_image_bounds
(
input_ids
,
im_start_id
,
im_end_id
,
slice_start_id
,
slice_end_id
),
pixel_values
=
pixel_values_flat
,
pixel_values
=
pixel_values_flat
,
tgt_sizes
=
torch
.
stack
(
tgt_sizes_flat
),
tgt_sizes
=
torch
.
stack
(
tgt_sizes_flat
),
)
)
...
@@ -570,8 +643,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
...
@@ -570,8 +643,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
def
get_vision_hidden_states
(
self
,
def
get_vision_hidden_states
(
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
def
is_default_weight_loading
(
self
,
name
:
str
)
->
bool
:
...
@@ -660,8 +733,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
...
@@ -660,8 +733,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
res
.
append
(
self
.
resampler
(
vision_embedding
,
tgt_size
))
res
.
append
(
self
.
resampler
(
vision_embedding
,
tgt_size
))
return
torch
.
vstack
(
res
)
return
torch
.
vstack
(
res
)
def
get_vision_hidden_states
(
self
,
def
get_vision_hidden_states
(
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
pixel_values
=
data
[
"pixel_values"
]
return
self
.
get_vision_embedding
(
pixel_values
)
return
self
.
get_vision_embedding
(
pixel_values
)
...
@@ -719,8 +792,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
...
@@ -719,8 +792,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
vision_embedding
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
vision_embedding
=
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
return
vision_embedding
return
vision_embedding
def
get_vision_hidden_states
(
self
,
def
get_vision_hidden_states
(
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
...
@@ -813,8 +886,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
...
@@ -813,8 +886,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
).
last_hidden_state
).
last_hidden_state
return
vision_embedding
return
vision_embedding
def
get_vision_hidden_states
(
self
,
def
get_vision_hidden_states
(
data
:
MiniCPMVImageInputs
)
->
torch
.
Tensor
:
self
,
data
:
MiniCPMVImage
Pixel
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
data
[
"pixel_values"
]
pixel_values
=
data
[
"pixel_values"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
tgt_sizes
=
data
[
"tgt_sizes"
]
...
@@ -857,7 +930,7 @@ _SUPPORT_VERSION = {
...
@@ -857,7 +930,7 @@ _SUPPORT_VERSION = {
}
}
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_minicpmv
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
...
@@ -884,7 +957,7 @@ class MiniCPMV(MiniCPMVBaseModel):
...
@@ -884,7 +957,7 @@ class MiniCPMV(MiniCPMVBaseModel):
version
=
str
(
config
.
version
).
split
(
"."
)
version
=
str
(
config
.
version
).
split
(
"."
)
version
=
tuple
([
int
(
x
)
for
x
in
version
])
version
=
tuple
([
int
(
x
)
for
x
in
version
])
# Dispatch class based on version
# Dispatch class based on version
instance_class
=
_SUPPORT_VERSION
.
get
(
version
,
None
)
instance_class
=
_SUPPORT_VERSION
.
get
(
version
)
if
instance_class
is
None
:
if
instance_class
is
None
:
raise
ValueError
(
raise
ValueError
(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6"
)
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6"
)
...
...
vllm/model_executor/models/mllama.py
0 → 100644
View file @
539aa992
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import
math
from
array
import
array
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
import
transformers.models.mllama.configuration_mllama
as
config_mllama
from
PIL
import
Image
from
torch
import
nn
from
transformers.modeling_outputs
import
(
BaseModelOutput
,
CausalLMOutputWithPast
)
from
transformers.models.mllama.image_processing_mllama
import
(
get_optimal_tiled_canvas
)
import
vllm.distributed.parallel_state
as
ps
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
.clip
import
CLIPMLP
from
.interfaces
import
SupportsMultiModal
from
.llama
import
LlamaDecoderLayer
,
LlamaMLP
logger
=
init_logger
(
__name__
)
MLLAMA_IMAGE_TOKEN_ID
=
128256
MLLAMA_IMAGE_TOKEN
=
"<|image|>"
class MllamaImagePixelInputs(TypedDict):
    """Typed dictionary for Mllama image inputs given as pixel values."""
    # Tag for this input variant; always the literal "pixel_values".
    type: Literal["pixel_values"]
    # Pixel tensor; the shape is given in the strings that follow.
    data: torch.Tensor
    """Shape: """
    """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)"""
    # Aspect ratio id per image.
    aspect_ratio_ids: torch.Tensor
    """Shape: `(batch_size, max_num_image)`"""
    # Per-tile mask for each image.
    aspect_ratio_mask: torch.Tensor
    """Shape: `(batch_size, max_num_image, max_num_tiles)`"""
# TODO: support LlamaImageEmbeddingInputs
def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
    """Rewrite ``llm_inputs`` in place for the Mllama encoder/decoder split.

    If no decoder prompt is set, the encoder prompt and its token ids are
    copied to the decoder side. For text-only requests the encoder prompt,
    token ids and multi-modal data are emptied. Otherwise the encoder prompt
    becomes a run of image tokens whose length is the total tile count of all
    images multiplied by the number of tokens per tile.

    Returns:
        The same ``llm_inputs`` object, after mutation.
    """
    # Move encoder_prompt to prompt when the decoder prompt is missing.
    if llm_inputs.get("prompt") is None:
        llm_inputs["prompt"] = llm_inputs["encoder_prompt"]
        llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"]

    assert "decoder_multi_modal_data" not in llm_inputs, \
        "multi-modal data should be put in encoder message of mllama"

    mm_data = llm_inputs.get("encoder_multi_modal_data")
    has_image = (mm_data is not None and "image" in mm_data
                 and mm_data["image"] is not None)
    if not has_image:
        # Text-only request: nothing is fed to the encoder.
        llm_inputs["encoder_prompt"] = ""
        llm_inputs["encoder_prompt_token_ids"] = []
        llm_inputs["encoder_multi_modal_data"] = {}
        return llm_inputs

    # Normalize a single image to a one-element list (written back in place).
    if isinstance(mm_data['image'], Image.Image):
        mm_data['image'] = [mm_data['image']]

    hf_config = ctx.model_config.hf_config
    vision_config = hf_config.vision_config

    # Count the tiles occupied by every image on its tiled canvas.
    total_tiles = 0
    for image in mm_data["image"]:
        width, height = image.size
        tile_size = vision_config.image_size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
            image_height=height,
            image_width=width,
            max_image_tiles=vision_config.max_num_tiles,
            tile_size=tile_size,
        )
        tiles_high = canvas_height // tile_size
        tiles_wide = canvas_width // tile_size
        total_tiles += tiles_high * tiles_wide

    # Set the encoder prompt based on the number of tiles.
    assert vision_config.image_size % 14 == 0, \
        "chunk size should be multiple of 14"
    token_per_chunk = (vision_config.image_size // 14)**2 + 1
    num_tokens = total_tiles * token_per_chunk
    llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
    llm_inputs["encoder_prompt_token_ids"] = \
        [MLLAMA_IMAGE_TOKEN_ID] * num_tokens

    return llm_inputs
def get_max_mllama_image_tokens(ctx: InputContext) -> int:
    """Return ``max_num_tiles`` times the number of tokens per tile.

    Tokens per tile is ``(image_size // 14) ** 2 + 1``, read from the HF
    vision config.
    """
    vision_config = ctx.model_config.hf_config.vision_config
    per_side = vision_config.image_size // 14
    token_per_chunk = per_side**2 + 1
    return vision_config.max_num_tiles * token_per_chunk
def dummy_decoder_seq_data(seq_len: int, num_images: int):
    """Build dummy decoder sequence data of length ``seq_len``.

    The sequence is ``num_images`` image tokens followed by zeros:
    <|image|> * num_images + 0 * (seq_len - num_images).
    """
    assert seq_len >= num_images, \
        "seq_len should be greater than or equal to num_images"
    image_part = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                       [MLLAMA_IMAGE_TOKEN_ID]) * num_images
    padding_part = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                         [0]) * (seq_len - num_images)
    return SequenceData(image_part + padding_part)
def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
    """Build dummy encoder sequence data made only of image tokens.

    The length is the per-image maximum token count from
    ``get_max_mllama_image_tokens`` multiplied by ``num_images``.
    """
    total_tokens = get_max_mllama_image_tokens(ctx) * num_images
    single_token = array(VLLM_TOKEN_ID_ARRAY_TYPE, [MLLAMA_IMAGE_TOKEN_ID])
    return SequenceData(single_token * total_tokens)
def dummy_image(num_images: int, ):
    """Return dummy multi-modal data holding black 1024x1024 RGB images.

    The ``"image"`` entry is a single image when ``num_images`` is 1, and
    otherwise a list that repeats the same image object ``num_images`` times.
    """
    side = 1024
    blank = Image.new("RGB", (side, side), color=0)
    if num_images == 1:
        return {"image": blank}
    return {"image": [blank] * num_images}
def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Return ``(dummy decoder sequence data, None)``.

    The decoder side carries no multi-modal data, so the second element is
    always ``None``.
    """
    seq_data = dummy_decoder_seq_data(seq_len, mm_counts["image"])
    return seq_data, None
def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Return ``(dummy encoder sequence data, dummy image data)``.

    ``seq_len`` is not used; the encoder length comes from the image count.
    """
    image_count = mm_counts["image"]
    seq_data = dummy_encoder_seq_data(ctx, image_count)
    mm_data = dummy_image(image_count)
    return seq_data, mm_data
def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Expand a per-tile mask into an additive 4D attention mask.

    Args:
        aspect_ratio_mask: 2D tensor of shape ``(batch_size, max_num_tiles)``.
        num_patches: Number of real (non-padding) patches per tile.
        target_length: Padded number of patches per tile.
        dtype: Floating dtype of the returned mask.

    Returns:
        Tensor of shape ``(batch_size, 1, max_num_tiles * target_length,
        max_num_tiles * target_length)``. An entry is ``torch.finfo(dtype).min``
        where both the row and the column position are masked (padding patch
        or masked tile), and zero elsewhere.
    """
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
                                            1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches. With no padding the slice would be `[-0:]`, which
    # selects the whole axis and would mask every patch, so skip it.
    pad_patches = target_length - num_patches
    if pad_patches > 0:
        attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
    attention_mask = attention_mask.reshape(batch_size,
                                            max_num_tiles * target_length, 1)
    attention_mask = attention_mask @ attention_mask.transpose(
        -1, -2) * torch.finfo(dtype).min
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask
class ColumnParallelConv2dPatch(torch.nn.Module):
    """Conv2D patching layer with model parallelism.

    The input is unfolded into patches, and each flattened patch is then
    projected by a column-parallel linear layer.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of the convolution kernel (int or pair).
        stride: Stride for the unfolding step.
        bias (default False): Use bias in the linear projection.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Normalize a scalar kernel size to a (height, width) pair.
        kernel = ((kernel_size, kernel_size)
                  if isinstance(kernel_size, int) else kernel_size)
        patch_dim = in_channels * kernel[0] * kernel[1]

        self._unfold = torch.nn.Unfold(kernel_size=kernel, stride=stride)
        self._linear = ColumnParallelLinear(
            patch_dim,
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold ``x`` into patches and project each patch."""
        # Swap the last two axes so each patch vector lies on the last axis.
        patches = self._unfold(x).permute(0, 2, 1)
        projected, _ = self._linear(patches)
        return projected
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a learned per-tile embedding, looked up by aspect ratio id.

    When ``is_gated`` is true the embedding is scaled by ``tanh`` of a
    learnable scalar gate that is initialized to zero.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        # One row per aspect ratio id; each row holds all tile embeddings.
        num_rows = self.max_aspect_ratio_id + 1
        row_width = self.max_num_tiles * self.hidden_size
        self.embedding = nn.Embedding(num_rows, row_width)
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_state`` plus the (optionally gated) embedding."""
        tile_embeddings = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size)

        if self.is_gated:
            tile_embeddings = tile_embeddings * self.gate.tanh()

        return hidden_state + tile_embeddings
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Adds gated patch-position and tile-position embeddings.

    A single learnable scalar gate blends the two: the patch position
    embedding is weighted by ``1 - tanh(gate)`` and the tile position
    embedding, looked up by aspect ratio id, by ``tanh(gate)``.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.num_patches = (config.image_size // config.patch_size)**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # Position embedding: random init scaled by hidden_size ** -0.5.
        initial_positions = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * initial_positions)

        # Tile position embedding: one flattened table row per aspect ratio.
        row_width = self.max_num_tiles * self.num_patches * self.hidden_size
        self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
                                           row_width)

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_state`` with both gated embeddings added."""
        # Patch position embeddings, broadcast over the leading two axes.
        position_term = (1 - self.gate.tanh()) * self.embedding
        hidden_state = hidden_state + position_term.view(
            1, 1, self.num_patches, self.hidden_size)

        # Precomputed tile position embeddings for the given aspect ratios.
        batch_size = hidden_state.shape[0]
        tile_table = self.tile_embedding(aspect_ratio_ids).reshape(
            batch_size, self.max_num_tiles, self.num_patches,
            self.hidden_size)
        tile_term = self.gate.tanh() * tile_table

        return hidden_state + tile_term
# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
    """Multi-head self-attention for the vision model, computed with SDPA.

    Heads are divided across tensor-parallel ranks: the fused QKV projection
    yields this rank's heads and the output projection takes them back as
    already-parallel input.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()

        tp_world_size = get_tensor_model_parallel_world_size()
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_local_heads = self.num_heads // tp_world_size
        local_width = self.num_local_heads * self.head_dim
        self.q_size = local_width
        self.kv_size = local_width

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_state`` and project the result."""

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # Split the last axis into (local heads, head_dim), then swap the
            # second axis with the heads axis.
            return t.view(t.shape[0], t.shape[1], self.num_local_heads,
                          self.head_dim).transpose(1, 2)

        qkv, _ = self.qkv_proj(hidden_state)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        # TODO: remove padding in image encoder
        attended = F.scaled_dot_product_attention(q,
                                                  k,
                                                  v,
                                                  attn_mask=attention_mask,
                                                  dropout_p=0.0)

        # Undo the head transpose and merge the heads back into one axis.
        attended = attended.transpose(1, 2).contiguous()
        merged = attended.reshape(attended.shape[0], attended.shape[1], -1)
        output, _ = self.o_proj(merged)
        return output
class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm encoder layer: self-attention then MLP, each wrapped in a
    residual connection that is optionally scaled by a tanh gate."""

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = False):
        """Build the attention, MLP and layer-norm sub-modules.

        Args:
            config: Vision configuration; ``hidden_size``,
                ``attention_heads``, ``intermediate_size`` and ``norm_eps``
                are read from it.
            is_gated: When true, create the learnable ``gate_attn`` and
                ``gate_ffn`` scalars used to scale both residual branches.
        """
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(config)
        self.mlp = CLIPMLP(config)

        self.input_layernorm = nn.LayerNorm(self.hidden_size,
                                            eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
                                                     eps=config.norm_eps)

        # Gate parameters exist only on gated layers; both start at pi / 4.
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Apply the attention block followed by the feed-forward block.

        Args:
            hidden_state: Layer input.
            attention_mask: Optional mask forwarded to the attention module.

        Returns:
            The layer output.
        """
        # Self-attention branch.
        attn_scale = self.gate_attn.tanh() if self.is_gated else 1
        attn_out = self.self_attn(self.input_layernorm(hidden_state),
                                  attention_mask=attention_mask)
        after_attn = hidden_state + attn_scale * attn_out

        # Feed-forward branch.
        ffn_scale = self.gate_ffn.tanh() if self.is_gated else 1
        ffn_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + ffn_scale * ffn_out
class MllamaVisionEncoder(nn.Module):
    """Stack of ``MllamaVisionEncoderLayer`` modules that can also return
    the hidden states seen at selected layer indices."""

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 num_layers=32,
                 is_gated=False,
                 output_hidden_states=None):
        """Build ``num_layers`` encoder layers.

        Args:
            config: Vision configuration handed to every layer.
            num_layers: Number of layers in the stack.
            is_gated: Forwarded to every layer.
            output_hidden_states: Layer indices whose hidden states are
                collected; a falsy value means none are collected.
        """
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(MllamaVisionEncoderLayer(config, is_gated))
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Run all layers in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Optional mask forwarded to every layer.

        Returns:
            ``(hidden_states, encoder_states)``: the final output and a
            tuple of the collected intermediate states.
        """
        wanted = self.output_hidden_states
        collected = []

        for layer_index, layer in enumerate(self.layers):
            # A selected index captures the input of that layer.
            if layer_index in wanted:
                collected.append(hidden_states)
            hidden_states = layer(hidden_states, attention_mask)

        # When the last layer index is selected, its output is captured too.
        if len(self.layers) - 1 in wanted:
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class MllamaVisionModel(nn.Module):
    """Vision tower: patch embedding, tile/position embeddings, a local
    encoder and a gated global encoder.

    The output concatenates the final hidden state with the intermediate
    hidden states collected by the local encoder along the last dimension.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        """Build all sub-modules from ``config``.

        Args:
            config: Vision configuration; image/patch sizes, hidden size,
                channel count, layer counts and
                ``intermediate_layers_indices`` are read from it.
        """
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patch grid plus one extra position (forward() prepends one class
        # embedding per tile).
        self.num_patches = (self.image_size // self.patch_size)**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale *
                                            torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            config)

        self.pre_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
        self.post_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        self.transformer = MllamaVisionEncoder(
            config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices)
        self.global_transformer = MllamaVisionEncoder(config,
                                                      config.num_global_layers,
                                                      is_gated=True)

    def apply_class_embedding(self,
                              hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding along dimension 1.

        Args:
            hidden_state: 3-D tensor (batch, sequence, hidden).

        Returns:
            Tensor with one extra leading position in dimension 1.
        """
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1,
                                                      hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(self, pixel_values: torch.Tensor,
                aspect_ratio_ids: torch.Tensor,
                aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
        """Encode image tiles.

        Args:
            pixel_values: 6-D tensor (batch_size, num_concurrent_media,
                num_tiles, num_channels, height, width).
            aspect_ratio_ids: Reshaped to
                (batch_size * num_concurrent_media, -1) and passed to the
                tile/position embedding modules.
            aspect_ratio_mask: Reshaped to
                (batch_size * num_concurrent_media, -1) and turned into the
                encoder attention mask.

        Returns:
            Tensor of shape (batch_size, num_concurrent_media, num_tiles,
            num_patches, D) where D is the hidden dimension plus the
            stacked intermediate-state features.
        """
        batch_size, num_concurrent_media, num_tiles, num_channels, \
            height, width = pixel_values.shape

        # Fold batch, media and tile dimensions so every tile is one image.
        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels,
            height, width)
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1)

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.layernorm_pre.weight.dtype))
        hidden_state = patch_embeds
        # presumably gathers the sharded projection output across
        # tensor-parallel ranks -- TODO confirm
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        # Account for the class token that was just prepended.
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles, num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state,
                                                       aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        # (rounds the per-tile sequence length up to a multiple of 8)
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0, 0, 0, num_padding_patches
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        # None keeps the full range when no padding was added; a plain
        # "-0" slice end would otherwise select nothing.
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        attention_mask = aspect_ratio_mask.reshape(
            batch_size * num_concurrent_media, -1)
        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=attention_mask,
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=self.layernorm_pre.weight.dtype,
        )

        # Merge tiles and (padded) patches into one sequence per media item.
        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
                                         dim)
        output = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        hidden_state, intermediate_hidden_states = output[0], output[1]
        # Stack the collected states on a new trailing dimension.
        intermediate_hidden_states = torch.stack(intermediate_hidden_states,
                                                 dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles,
                                            num_patches + num_padding_patches,
                                            dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches), dim)
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask)[0]
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles,
                                            num_patches + num_padding_patches,
                                            dim)
        # Drop the padding positions added before the encoders.
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
                                            num_tiles, num_patches, dim)
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media, num_tiles,
            num_patches + num_padding_patches, -1)
        intermediate_hidden_states = intermediate_hidden_states[:, :, :
                                                                slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1)
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
                                 dim=-1)
        return hidden_state
class MllamaTextRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature scale.

    The statistics are computed in float32 and the normalised values are
    cast back to the input dtype before the scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        MllamaTextRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` over its last dimension.

        Args:
            hidden_states: Tensor whose last dimension has ``hidden_size``
                features.

        Returns:
            ``weight * normalised`` where ``normalised`` has the input
            dtype.
        """
        source_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (upcast * inv_rms).to(source_dtype)
        return self.weight * normalised

    def extra_repr(self):
        """Return the weight shape and epsilon for ``repr(module)``."""
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
class MllamaTextCrossAttention(nn.Module):
    """Cross-attention for the text decoder.

    Queries come from the decoder hidden states; keys and values come from
    ``cross_attention_states`` when they are provided. Queries and keys are
    RMS-normalised per head before attention.
    """

    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_idx: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build projections, per-head norms and the attention op.

        Args:
            config: Text configuration. Although the default is ``None``,
                its attributes are read unconditionally.
            layer_idx: Stored on the module as ``layer_idx``.
            quant_config: Forwarded to the linear projections.
        """
        super().__init__()
        self.config = config
        self.model_parallel_size = get_tensor_model_parallel_world_size()
        tp_size = self.model_parallel_size

        # Head counts: global, then per tensor-parallel rank.
        self.num_heads = self.config.num_attention_heads
        self.num_local_heads = self.num_heads // tp_size
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_local_key_value_heads = self.num_key_value_heads // tp_size

        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # Widths of the local q and k/v slices of the fused projection.
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        # TODO: change to Q/KV separate linear after #7448 is merged
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        norm_eps = config.rms_norm_eps
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = Attention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
        )

    def _project(self, states: torch.Tensor):
        """Run the fused projection and split it into (q, k, v) slices."""
        fused, _ = self.qkv_proj(states)
        sizes = [self.q_local_size, self.kv_local_size, self.kv_local_size]
        return fused.split(sizes, dim=-1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cross_attention_states: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend from ``hidden_states`` to ``cross_attention_states``.

        Args:
            hidden_states: Decoder hidden states (query source).
            attention_mask: Accepted for interface compatibility; not used
                by this method.
            cross_attention_states: Key/value source, or ``None`` to pass
                ``None`` keys and values to the attention op.
            kv_cache: Cache tensor forwarded to the attention op.
            attn_metadata: Metadata forwarded to the attention op.

        Returns:
            The first element returned by the output projection.
        """
        # Only the query slice of the decoder-side projection is used.
        query, _, _ = self._project(hidden_states)

        key = None
        value = None
        if cross_attention_states is not None:
            # Only the key/value slices of the encoder-side projection are
            # used.
            _, key, value = self._project(cross_attention_states)
            kv_shape = (-1, self.num_local_key_value_heads, self.head_dim)
            key = key.view(*kv_shape)
            value = value.view(*kv_shape)
            key = self.k_norm(key)

        query = query.view(-1, self.num_local_heads, self.head_dim)
        query = self.q_norm(query)

        attended = self.attn(query,
                             key,
                             value,
                             kv_cache,
                             attn_metadata,
                             attn_type=AttentionType.ENCODER_DECODER)
        projected, _ = self.o_proj(attended)
        return projected
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Decoder block made of a cross-attention branch and an MLP branch.

    Each branch output is multiplied by ``full_text_row_masked_out_mask``
    and by the tanh of a learnable scalar gate before it is added to the
    residual stream.
    """

    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
                 quant_config: Optional[QuantizationConfig]) \
        -> None:
        """Build the attention, MLP, norm and gate parameters.

        Args:
            config: Text configuration; ``hidden_size``, ``rms_norm_eps``,
                ``intermediate_size`` and ``hidden_act`` are read from it.
            layer_idx: Stored on the module and passed to the attention.
            quant_config: Forwarded to the attention and MLP modules.
        """
        super().__init__()
        self.layer_idx = layer_idx
        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        # Attention branch.
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_idx=layer_idx,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(width, eps=norm_eps)
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        # Feed-forward branch.
        self.mlp = LlamaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.post_attention_layernorm = RMSNorm(width, eps=norm_eps)
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the gated cross-attention and MLP branches.

        Args:
            hidden_states: Residual-stream input.
            cross_attention_states: Key/value source for cross-attention.
            cross_attention_mask: Passed to the attention as
                ``attention_mask``.
            full_text_row_masked_out_mask: Multiplied into both branch
                outputs.
            kv_cache: Cache forwarded to the attention module.
            attn_metadata: Metadata forwarded to the attention module.

        Returns:
            The updated residual stream.
        """
        attn_out = self.cross_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attn_out = full_text_row_masked_out_mask * attn_out
        after_attn = (hidden_states +
                      self.cross_attn_attn_gate.tanh() * attn_out)

        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        mlp_out = full_text_row_masked_out_mask * mlp_out
        return after_attn + self.cross_attn_mlp_gate.tanh() * mlp_out
class MllamaTextModel(nn.Module):
    """Text decoder that mixes self-attention layers with cross-attention
    layers at the indices listed in ``config.cross_attention_layers``."""

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(self, config: config_mllama.MllamaTextConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig]):
        """Build the embedding table, decoder layers and final norm.

        Args:
            config: Text configuration.
            cache_config: Forwarded to the self-attention decoder layers.
            quant_config: Forwarded to every decoder layer.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # The table holds 8 rows beyond config.vocab_size.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
                                                   config.hidden_size)
        self.cross_attention_layers = config.cross_attention_layers

        def build_layer(layer_idx: int) -> nn.Module:
            if layer_idx in self.cross_attention_layers:
                return MllamaCrossAttentionDecoderLayer(
                    config, layer_idx, quant_config=quant_config)
            # TODO: force LlamaDecoderLayer to config.attention_bias=False
            return LlamaDecoderLayer(config,
                                     cache_config=cache_config,
                                     quant_config=quant_config)

        self.layers = nn.ModuleList(
            [build_layer(idx) for idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Positions forwarded to self-attention layers.
            cross_attention_states: Forwarded to cross-attention layers.
            cross_attention_mask: Forwarded to cross-attention layers.
            full_text_row_masked_out_mask: Forwarded to cross-attention
                layers.
            kv_caches: One cache per layer, indexed by layer position.
            attn_metadata: Forwarded to every layer.
            skip_cross_attention: When true, cross-attention layers are
                skipped entirely.

        Returns:
            The normalised final hidden states.

        Raises:
            ValueError: If a layer is of an unexpected type.
        """
        hidden_states = self.embed_tokens(input_ids)

        for layer_pos, layer in enumerate(self.layers):
            if isinstance(layer, MllamaCrossAttentionDecoderLayer):
                if skip_cross_attention:
                    continue
                hidden_states = layer(
                    hidden_states=hidden_states,
                    cross_attention_states=cross_attention_states,
                    cross_attention_mask=cross_attention_mask,
                    full_text_row_masked_out_mask=
                    full_text_row_masked_out_mask,
                    kv_cache=kv_caches[layer_pos],
                    attn_metadata=attn_metadata,
                )
                continue

            if isinstance(layer, LlamaDecoderLayer):
                # The layer returns (output, residual); summing them yields
                # the residual stream for the next layer.
                layer_out, residual = layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    kv_cache=kv_caches[layer_pos],
                    attn_metadata=attn_metadata,
                    residual=None,
                )
                hidden_states = layer_out + residual
                continue

            raise ValueError(f"Unknown decoder layer type {type(layer)}")

        return self.norm(hidden_states)
class MllamaForCausalLM(nn.Module):
    """Text decoder (``model``) together with its LM head (``lm_head``).

    ``forward`` returns hidden states only; it does not apply ``lm_head``.
    """

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
    ]

    def __init__(self, config: config_mllama.MllamaTextConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig]):
        """Build the text model and the LM head.

        Args:
            config: Text configuration.
            cache_config: Forwarded to the text model.
            quant_config: Forwarded to the text model and the LM head.
        """
        super().__init__()
        vocab_size = config.vocab_size
        self.vocab_size = vocab_size
        self.model = MllamaTextModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            vocab_size,
            config.hidden_size,
            org_num_embeddings=vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Delegate to the text model and return its hidden states.

        All arguments are forwarded unchanged by keyword.
        """
        return self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """Vision tower + projector + cross-attention text decoder.

    Images are encoded by ``vision_model``, projected to the text hidden
    size by ``multi_modal_projector`` and consumed by the cross-attention
    layers of ``language_model``.
    """

    def __init__(self,
                 config: config_mllama.MllamaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the vision model, language model, projector and sampler.

        Args:
            config: Combined configuration with ``vision_config`` and
                ``text_config`` sub-configurations.
            multimodal_config: Accepted for interface compatibility; not
                used by this constructor.
            cache_config: Forwarded to the language model.
            quant_config: Forwarded to the language model.
        """
        super().__init__()
        text_config = config.text_config
        vision_config = config.vision_config

        self.vocab_size = text_config.vocab_size
        self.hidden_size = text_config.hidden_size
        self.max_num_tiles = vision_config.max_num_tiles
        self.vision_output_dim = vision_config.vision_output_dim
        # -1 stands in for a missing pad token id.
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = vision_config.image_size

        self.vision_model = MllamaVisionModel(vision_config)
        self.language_model = MllamaForCausalLM(
            text_config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.multi_modal_projector = nn.Linear(
            vision_config.vision_output_dim,
            text_config.hidden_size,
            bias=True,
        )
        # NOTE(review): the first positional argument here is
        # config.output_hidden_states, not a vocabulary size -- confirm
        # against the LogitsProcessor signature.
        self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                text_config.vocab_size)
        self.sampler = Sampler()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` with the LM head."""
        return self.logits_processor(self.language_model.lm_head,
                                     hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Pop image-related kwargs and pack them into padded tensors.

        Returns:
            ``None`` when neither pixel values nor image embeds are given,
            otherwise a ``MllamaImagePixelInputs`` holding zero-padded
            pixel values plus aspect-ratio ids and masks.

        Raises:
            ValueError: If both pixel values and image embeds are given,
                if pixel values come without aspect-ratio ids or mask, or
                if no request in the batch contains an image.
            NotImplementedError: If only image embeds are given.
        """
        # tensor with the same shape will be batched together by
        # MultiModalInputs.batch, so pixel_values here can be:
        #   - List[List[torch.Tensor]]:
        #       with shape (num_tiles, 3, image_res, image_res)
        #   - List[torch.Tensor]:
        #       with shape (num_image, num_tiles, 3, image_res, image_res)
        #   - torch.Tensor:
        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
        pixel_values: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
        aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]],
                                         List[torch.Tensor],
                                         torch.Tensor]] = kwargs.pop(
                                             "aspect_ratio_ids", None)
        aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]],
                                          List[torch.Tensor],
                                          torch.Tensor]] = kwargs.pop(
                                              "aspect_ratio_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            # Explicit checks instead of asserts so validation survives
            # ``python -O``.
            if aspect_ratio_ids is None:
                raise ValueError(
                    "aspect_ratio_ids must be provided with pixel_values.")
            if aspect_ratio_mask is None:
                raise ValueError(
                    "aspect_ratio_mask must be provided with pixel_values.")

            max_num_images = max(len(sample[0]) for sample in pixel_values)
            if max_num_images == 0:
                raise ValueError("No images provided.")
            # default=0 tolerates a batch entry without images when other
            # entries do have images.
            max_num_tiles = max(
                max((len(img) for img in sample[0]), default=0)
                for sample in pixel_values)

            device = self.multi_modal_projector.weight.device
            bsz = len(pixel_values)
            out_images = torch.zeros(
                bsz,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
                device=device,
            )
            # Padding slots keep aspect-ratio id 1 and an all-zero mask.
            out_ar_ids = torch.ones(bsz,
                                    max_num_images,
                                    dtype=torch.int64,
                                    device=device)
            out_ar_mask = torch.zeros(bsz,
                                      max_num_images,
                                      max_num_tiles,
                                      dtype=torch.int64,
                                      device=device)
            for b, sample in enumerate(pixel_values):
                for i, img in enumerate(sample[0]):
                    # img.shape[0] is this image's tile count; the
                    # remaining tile slots stay zero.
                    out_images[b, i, :img.shape[0]] = img
                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]

            return MllamaImagePixelInputs(
                type="pixel_values",
                data=out_images,
                aspect_ratio_ids=out_ar_ids,
                aspect_ratio_mask=out_ar_mask,
            )

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata):
        """Flatten per-request vision tokens and build the text-row mask.

        Args:
            cross_attention_states: Per-request vision tokens; iterated
                along dimension 0 in step with
                ``attn_metadata.encoder_seq_lens``.
            attn_metadata: Provides ``encoder_seq_lens``,
                ``seq_lens_tensor`` and ``num_prefill_tokens``.

        Returns:
            ``(flat_states, mask)``. ``flat_states`` has
            ``sum(encoder_seq_lens)`` rows. ``mask`` is a boolean
            ``(num_prefill_tokens, 1)`` tensor that is ``False`` on the
            text rows of requests whose encoder length is 0.
        """
        encoder_seq_lens = attn_metadata.encoder_seq_lens

        flat_states = torch.zeros(sum(encoder_seq_lens),
                                  cross_attention_states.shape[-1],
                                  device=cross_attention_states.device,
                                  dtype=cross_attention_states.dtype)
        start_pos = 0
        for enc_len, vision_tokens in zip(encoder_seq_lens,
                                          cross_attention_states):
            end_pos = start_pos + enc_len
            # Keep only the first enc_len tokens of each request.
            flat_states[start_pos:end_pos] = vision_tokens[:enc_len]
            start_pos = end_pos

        # The mask is assembled on the CPU and moved to the device once.
        row_mask = torch.ones((attn_metadata.num_prefill_tokens, 1),
                              dtype=torch.bool)
        start_pos = 0
        # tolist() yields plain ints for slicing.
        text_seq_lens = attn_metadata.seq_lens_tensor.tolist()
        for seq_len, enc_len in zip(text_seq_lens, encoder_seq_lens):
            if enc_len == 0:
                row_mask[start_pos:start_pos + seq_len] = False
            start_pos += seq_len
        row_mask = row_mask.to(flat_states.device)

        return flat_states, row_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the vision tower (if images are given) and the text decoder.

        Args:
            input_ids: Token ids for the text decoder.
            positions: Positions for the text decoder.
            kv_caches: Per-layer caches for the text decoder.
            attn_metadata: Attention metadata for this step.
            **kwargs: May carry ``pixel_values``, ``aspect_ratio_ids`` and
                ``aspect_ratio_mask``.

        Returns:
            The output of ``language_model``.

        Raises:
            ValueError: If the batch mixes prefill and decode tokens.
        """
        if attn_metadata.num_prefill_tokens > 0 and \
            attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")

        image_inputs = self._parse_and_validate_image_input(**kwargs)
        if image_inputs is None:
            cross_attention_mask = None
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                    input_ids.device)
            cross_attention_states = None
            # Cross-attention is skipped only when every encoder length
            # in the batch is 0.
            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
        else:
            # NOTE: llama's reference implementation runs vision model on CPU
            pixel_values = image_inputs['data']
            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
            cross_attention_states = self.vision_model(pixel_values,
                                                       aspect_ratio_ids,
                                                       aspect_ratio_mask)
            cross_attention_states = self.multi_modal_projector(
                cross_attention_states)

            # Collapse (images, tiles, patches) into one token axis per
            # request.
            bsz, _, _, _, image_token_dim = tuple(
                cross_attention_states.shape)
            cross_attention_states = cross_attention_states.view(
                bsz, -1, image_token_dim)

            cross_attention_states, full_text_row_masked_out_mask = \
                self.flat_encoder_result(
                    cross_attention_states, attn_metadata)
            skip_cross_attention = False
            # TODO: support multi-image by this mask
            cross_attention_mask = None

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and
        ``gate_proj``/``up_proj`` weights are routed into the fused
        ``qkv_proj`` and ``gate_up_proj`` parameters. The patch-embedding
        convolution weight is flattened to 2-D and loaded into
        ``patch_embedding._linear.weight``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Raises:
            KeyError: If a name matches no remaining parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'patch_embedding.weight' in name:
                name = name.replace('patch_embedding.weight',
                                    'patch_embedding._linear.weight')
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Fused parameters receive several shards, so they stay
                # in the dict.
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Non-stacked parameters are popped once loaded.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
vllm/model_executor/models/olmoe.py
0 → 100644
View file @
539aa992
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
print_warning_once
class OlmoeMoE(nn.Module):
    """Tensor-parallel mixture-of-experts block for Olmoe.

    Every expert's weights are sharded across all ranks. A replicated
    linear gate produces the router logits, a fused MoE kernel runs the
    experts, and the outputs are reduced across ranks.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        """Build the router gate and the fused expert layer.

        Args:
            num_experts: Number of experts (also the gate output width).
            top_k: Forwarded to the fused MoE layer.
            hidden_size: Input/output feature size.
            intermediate_size: Forwarded to the fused MoE layer.
            params_dtype: Accepted but not used by this constructor.
            quant_config: Forwarded to the experts only.
            tp_size: Forwarded to the fused MoE layer.
            prefix: Accepted but not used by this constructor.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
        )

        expert_kwargs = dict(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=True,
            renormalize=False,
            quant_config=quant_config,
            tp_size=tp_size,
        )
        self.experts = FusedMoE(**expert_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts.

        Args:
            hidden_states: Input of any leading shape; the last dimension
                is the feature dimension.

        Returns:
            Tensor with the same shape as the input.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        expert_out = self.experts(hidden_states=tokens,
                                  router_logits=router_logits)
        return expert_out.view(input_shape)
class OlmoeAttention(nn.Module):
    """Self-attention block used by the OLMoE decoder layers.

    The block applies a fused QKV projection, RMSNorm on the query and key
    projections, rotary position embedding, attention, and finally the
    output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the projections, norms, rotary embedding and attention.

        Args:
            hidden_size: Width of the hidden states.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP
                sharding).
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Optional scaling config passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to
                ``get_rope``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        # NOTE(review): both norms are sized for the un-sharded projections
        # while forward() applies them to the per-rank q/k slices, so this
        # looks correct only for tp_size == 1 -- confirm before enabling
        # tensor parallelism for this model.
        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
        # The key projection is total_num_kv_heads * head_dim wide, which
        # differs from hidden_size whenever num_kv_heads != num_heads.
        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
                              eps=1e-5)

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute self-attention for ``hidden_states``.

        Args:
            positions: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: KV cache handed to ``Attention``.
            attn_metadata: Attention metadata handed to ``Attention``.

        Returns:
            The output of the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
        # split() returns strided views; the norms get contiguous copies.
        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class OlmoeDecoderLayer(nn.Module):
    """One OLMoE transformer layer: self-attention followed by the MoE MLP.

    Both sub-blocks are preceded by an RMSNorm; the residual stream is
    carried alongside the hidden states and returned to the caller.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, MoE and normalization sub-modules.

        Args:
            config: Model config. ``rope_theta``, ``rope_scaling``,
                ``max_position_embeddings`` and ``rms_norm_eps`` are
                optional and fall back to 10000, None, 4096 and 1e-5.
            layer_idx: Index of this layer; not used by the layer itself.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention and MoE blocks.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)
        # Honour the checkpoint's epsilon when the config provides one;
        # 1e-5 is the value that was previously hard-coded here.
        rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)

        self.self_attn = OlmoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )

        self.mlp = OlmoeMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        Args:
            positions: Positions forwarded to the attention block.
            hidden_states: Input hidden states.
            kv_cache: KV cache of this layer.
            attn_metadata: Attention metadata.
            residual: Residual from the previous layer, or ``None`` for the
                first layer, in which case ``hidden_states`` becomes the
                residual.

        Returns:
            ``(hidden_states, residual)``: the MoE output and the residual
            to hand to the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns the updated
            # residual alongside the normalized states.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class OlmoeModel(nn.Module):
    """OLMoE backbone: token embedding, decoder layers and final RMSNorm."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` layers and norm.

        Args:
            config: Model config. ``rms_norm_eps`` is optional and falls
                back to 1e-5.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            OlmoeDecoderLayer(config,
                              layer_idx,
                              cache_config,
                              quant_config=quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ])
        # Honour the checkpoint's epsilon when the config provides one;
        # 1e-5 is the value that was previously hard-coded here.
        self.norm = RMSNorm(config.hidden_size,
                            eps=getattr(config, "rms_norm_eps", 1e-5))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Positions forwarded to every layer.
            kv_caches: One KV cache per decoder layer, indexed by layer.
            attn_metadata: Attention metadata forwarded to every layer.

        Returns:
            The hidden states after the final norm.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual and creates it.
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class OlmoeForCausalLM(nn.Module):
    """OLMoE backbone plus LM head, logits processor and sampler."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the backbone, the LM head and the sampling helpers."""
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OlmoeModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the backbone's hidden states.

        ``intermediate_tensors`` is accepted but not used.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the LM head on ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the model's parameters.

        Each ``(name, tensor)`` pair is tried against three routes in order:
        the stacked (fused) projections, the per-expert parameters, and
        finally a plain load under the checkpoint name.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip entries that do not apply. Expert weights such as
                # mlp.experts[0].gate_proj must be skipped BEFORE the name
                # is rewritten: they are handled by expert_params_mapping
                # below, and rewriting here as well would turn the name into
                # mlp.experts[0].gate_gate_up_proj, which breaks the load.
                if weight_name not in name or "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip names the model does not have (this also covers the
                # extra biases of GPTQ checkpoints).
                if name not in params_dict:
                    continue

                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for (param_name, weight_name, expert_id,
                     shard_id) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    param.weight_loader(param,
                                        loaded_weight,
                                        name,
                                        shard_id=shard_id,
                                        expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        name = remapped_kv_scale_name

                    param = params_dict[name]
                    # Parameters without a custom loader use the default.
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
vllm/model_executor/models/paligemma.py
View file @
539aa992
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
...
@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
filter
_weights
,
merge_multimodal_embeddings
from
.utils
import
group
_weights
_with_prefix
,
merge_multimodal_embeddings
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -153,7 +152,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -153,7 +152,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
config
.
text_config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -286,21 +286,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -286,21 +286,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
# prepare weight iterators for components
vit_
weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
weights
_group
=
group_weights_with_prefix
(
weights
)
# load vision tower
# load vision tower
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
self
.
vision_tower
.
load_weights
(
vit_weights
)
# load mlp projector
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_
weights
:
for
name
,
loaded_weight
in
weights
_group
[
"multi_modal_projector"
]
:
param
=
mlp_params_dict
[
name
]
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
self
.
language_model
.
load_weights
(
llm_weights
)
vllm/model_executor/models/persimmon.py
View file @
539aa992
...
@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
...
@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
text_
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
hidden_size
)
config
.
text_config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
PersimmonDecoderLayer
(
config
,
PersimmonDecoderLayer
(
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
...
@@ -257,14 +257,14 @@ class PersimmonForCausalLM(nn.Module):
...
@@ -257,14 +257,14 @@ class PersimmonForCausalLM(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
text_
config
.
vocab_size
self
.
model
=
PersimmonModel
(
config
,
self
.
model
=
PersimmonModel
(
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
text_config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
False
)
bias
=
False
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/phi3.py
0 → 100644
View file @
539aa992
# coding=utf-8
# Adapted from llama.py
"""Inference-only Phi3 model code inherit from Llama.py"""
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
class Phi3ForCausalLM(LlamaForCausalLM):
    """Phi-3 causal LM that reuses the Llama implementation.

    The only thing overridden is ``packed_modules_mapping``.
    """

    # Each packed module maps only to itself.
    # NOTE(review): presumably because Phi-3 checkpoints already store the
    # fused qkv / gate_up projections -- confirm against the checkpoint.
    packed_modules_mapping = {
        "qkv_proj": ["qkv_proj"],
        "gate_up_proj": ["gate_up_proj"],
    }
vllm/model_executor/models/phi3v.py
View file @
539aa992
...
@@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
...
@@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def
_calc_hd_transform_size
(
*
,
width
:
int
,
height
:
int
,
hd_num
:
int
=
16
):
def
_calc_hd_transform_size
(
*
,
width
:
int
,
height
:
int
,
hd_num
:
int
):
transposed
=
False
transposed
=
False
if
width
<
height
:
if
width
<
height
:
width
,
height
=
height
,
width
width
,
height
=
height
,
width
...
@@ -337,8 +337,10 @@ def get_phi3v_image_feature_size(
...
@@ -337,8 +337,10 @@ def get_phi3v_image_feature_size(
*
,
*
,
input_height
:
int
,
input_height
:
int
,
input_width
:
int
,
input_width
:
int
,
num_crops
:
int
,
)
->
int
:
)
->
int
:
num_crops
=
hf_config
.
get
(
"num_crops"
,
16
)
if
num_crops
is
None
:
num_crops
=
hf_config
.
get
(
"num_crops"
,
16
)
new_width
,
new_height
=
_calc_hd_transform_size
(
width
=
input_width
,
new_width
,
new_height
=
_calc_hd_transform_size
(
width
=
input_width
,
height
=
input_height
,
height
=
input_height
,
hd_num
=
num_crops
)
hd_num
=
num_crops
)
...
@@ -347,20 +349,26 @@ def get_phi3v_image_feature_size(
...
@@ -347,20 +349,26 @@ def get_phi3v_image_feature_size(
+
(
new_height
//
336
+
1
)
*
12
+
(
new_height
//
336
+
1
)
*
12
def
get_max_phi3v_image_tokens
(
ctx
:
InputContext
):
def
get_max_phi3v_image_tokens
(
ctx
:
InputContext
,
*
,
num_crops
:
Optional
[
int
]
=
None
):
return
get_phi3v_image_feature_size
(
return
get_phi3v_image_feature_size
(
ctx
.
get_hf_image_processor_config
(),
ctx
.
get_hf_image_processor_config
(),
input_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
input_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
input_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
input_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
num_crops
=
num_crops
,
)
)
def
dummy_data_for_phi3v
(
ctx
:
InputContext
,
seq_len
:
int
,
def
dummy_data_for_phi3v
(
ctx
:
InputContext
,
mm_counts
:
Mapping
[
str
,
int
]):
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
*
,
num_crops
:
Optional
[
int
]
=
None
):
num_images
=
mm_counts
[
"image"
]
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_phi3v_image_tokens
(
ctx
)
image_feature_size
=
get_max_phi3v_image_tokens
(
ctx
,
num_crops
=
num_crops
)
seq_data
=
dummy_seq_data_for_clip
(
seq_data
=
dummy_seq_data_for_clip
(
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
...
@@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig,
...
@@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig,
return
image_placeholder_token_ids
return
image_placeholder_token_ids
def
input_processor_for_phi3v
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
def
input_processor_for_phi3v
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
,
*
,
num_crops
:
Optional
[
int
]
=
None
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_inputs
return
llm_inputs
...
@@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size
=
[
image_feature_size
=
[
get_phi3v_image_feature_size
(
hf_config
,
get_phi3v_image_feature_size
(
hf_config
,
input_width
=
w
,
input_width
=
w
,
input_height
=
h
)
input_height
=
h
,
num_crops
=
num_crops
)
]
]
image_data
=
[
image_data
]
image_data
=
[
image_data
]
elif
is_list_of
(
image_data
,
Image
.
Image
):
elif
is_list_of
(
image_data
,
Image
.
Image
):
...
@@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size
.
append
(
image_feature_size
.
append
(
get_phi3v_image_feature_size
(
hf_config
,
get_phi3v_image_feature_size
(
hf_config
,
input_width
=
w
,
input_width
=
w
,
input_height
=
h
))
input_height
=
h
,
num_crops
=
num_crops
))
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
elif
is_list_of
(
image_data
,
torch
.
Tensor
):
elif
is_list_of
(
image_data
,
torch
.
Tensor
):
...
...
vllm/model_executor/models/phimoe.py
View file @
539aa992
...
@@ -321,13 +321,13 @@ class PhiMoEAttention(nn.Module):
...
@@ -321,13 +321,13 @@ class PhiMoEAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
True
,
bias
=
True
,
quant_config
=
None
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
None
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
@@ -491,6 +491,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
...
@@ -491,6 +491,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
"o_proj"
,
"o_proj"
,
"embed_tokens"
,
"embed_tokens"
,
"lm_head"
,
"lm_head"
,
"w1"
,
"w2"
,
"w3"
,
"gate"
,
]
]
embedding_modules
=
{
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"embed_tokens"
:
"input_embeddings"
,
...
...
vllm/model_executor/models/pixtral.py
View file @
539aa992
from
array
import
array
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
itertools
import
tee
from
itertools
import
tee
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Union
...
@@ -24,8 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -24,8 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
SequenceData
)
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
from
.utils
import
init_vllm_registered_model
from
.utils
import
init_vllm_registered_model
...
@@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
...
@@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
image_feature_size
=
(
size
**
2
)
//
(
patch_size
**
2
)
image_feature_size
=
(
size
**
2
)
//
(
patch_size
**
2
)
num_image_tokens
=
image_feature_size
*
num_images
num_image_tokens
=
image_feature_size
*
num_images
seq_data
=
SequenceData
.
from_token_counts
(
(
image_token_id
,
num_image_tokens
),
(
0
,
seq_len
-
num_image_tokens
),
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
num_image_tokens
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
num_image_tokens
)
seq_data
=
SequenceData
(
token_ids
)
mm_data
=
{
"image"
:
num_images
*
[
image
]}
mm_data
=
{
"image"
:
num_images
*
[
image
]}
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
@@ -454,7 +450,7 @@ class Transformer(nn.Module):
...
@@ -454,7 +450,7 @@ class Transformer(nn.Module):
return
x
return
x
def
position_meshgrid
(
patch_embeds_list
:
l
ist
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
def
position_meshgrid
(
patch_embeds_list
:
L
ist
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
positions
=
torch
.
cat
([
positions
=
torch
.
cat
([
torch
.
stack
(
torch
.
stack
(
torch
.
meshgrid
(
torch
.
meshgrid
(
...
...
vllm/model_executor/models/qwen.py
View file @
539aa992
...
@@ -7,7 +7,6 @@
...
@@ -7,7 +7,6 @@
import
math
import
math
import
re
import
re
from
array
import
array
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
from
typing
import
(
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
Optional
,
Tuple
,
TypedDict
,
Union
)
...
@@ -48,8 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -48,8 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
SequenceData
)
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
...
@@ -689,8 +687,9 @@ def input_processor_for_qwen(ctx: InputContext,
...
@@ -689,8 +687,9 @@ def input_processor_for_qwen(ctx: InputContext,
prompt
=
llm_inputs
.
get
(
"prompt"
)
prompt
=
llm_inputs
.
get
(
"prompt"
)
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer
=
cached_get_tokenizer
(
trust_remote_code
=
True
)
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_data
=
multi_modal_data
[
"image"
]
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
torch
.
Tensor
):
if
isinstance
(
image_data
,
torch
.
Tensor
):
num_dims
=
len
(
image_data
.
shape
)
num_dims
=
len
(
image_data
.
shape
)
...
@@ -750,8 +749,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
...
@@ -750,8 +749,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
return
MultiModalInputs
()
return
MultiModalInputs
()
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer
=
cached_get_tokenizer
(
trust_remote_code
=
True
)
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_pair_tok
=
tokenizer
.
encode
(
IMG_START
+
IMG_END
,
image_pair_tok
=
tokenizer
.
encode
(
IMG_START
+
IMG_END
,
add_special_tokens
=
False
,
add_special_tokens
=
False
,
...
@@ -832,15 +832,16 @@ def dummy_data_for_qwen(
...
@@ -832,15 +832,16 @@ def dummy_data_for_qwen(
# The presence of a visual config indicates this is a multimodal model.
# The presence of a visual config indicates this is a multimodal model.
# If we don't have it, the model is considered an LLM for warmup purposes.
# If we don't have it, the model is considered an LLM for warmup purposes.
if
not
hasattr
(
hf_config
,
"visual"
):
if
not
hasattr
(
hf_config
,
"visual"
):
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]
*
seq_len
))
seq_data
=
SequenceData
.
from_token_counts
((
0
,
seq_len
))
mm_data
=
None
mm_data
=
None
return
seq_data
,
mm_data
return
seq_data
,
mm_data
# We have a visual component - use images to warm up
# We have a visual component - use images to warm up
num_images
=
mm_counts
[
"image"
]
num_images
=
mm_counts
[
"image"
]
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer
=
cached_get_tokenizer
(
trust_remote_code
=
True
)
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
# Build the image prompts with no imgpads; the tokenizer will add img pads
# Build the image prompts with no imgpads; the tokenizer will add img pads
image_prompt
=
''
.
join
(
image_prompt
=
''
.
join
(
...
@@ -859,11 +860,13 @@ def dummy_data_for_qwen(
...
@@ -859,11 +860,13 @@ def dummy_data_for_qwen(
if
len
(
toks
)
<
seq_len
:
if
len
(
toks
)
<
seq_len
:
toks
+=
[
0
]
*
(
seq_len
-
len
(
toks
))
toks
+=
[
0
]
*
(
seq_len
-
len
(
toks
))
seq_data
=
SequenceData
.
from_seqs
(
toks
)
# Build the input images; width/height doesn't actually matter here since
# Build the input images; width/height doesn't actually matter here since
# the data will get resized and the # of tokens per image is constant
# the data will get resized and the # of tokens per image is constant
image
=
Image
.
new
(
"RGB"
,
(
224
,
224
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
224
,
224
),
color
=
0
)
mm_data
=
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
mm_data
=
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
return
S
eq
uenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
toks
))
,
mm_data
return
s
eq
_data
,
mm_data
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_qwen
)
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_qwen
)
...
...
vllm/model_executor/models/qwen2.py
View file @
539aa992
...
@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
...
@@ -247,11 +247,16 @@ class Qwen2Model(nn.Module):
...
@@ -247,11 +247,16 @@ class Qwen2Model(nn.Module):
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
if
get_pp_group
().
is_first_rank
or
(
config
.
tie_word_embeddings
config
.
vocab_size
,
and
get_pp_group
().
is_last_rank
):
config
.
hidden_size
,
self
.
embed_tokens
=
VocabParallelEmbedding
(
quant_config
=
quant_config
,
config
.
vocab_size
,
)
config
.
hidden_size
,
quant_config
=
quant_config
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen2DecoderLayer
(
config
=
config
,
lambda
prefix
:
Qwen2DecoderLayer
(
config
=
config
,
...
@@ -260,7 +265,10 @@ class Qwen2Model(nn.Module):
...
@@ -260,7 +265,10 @@ class Qwen2Model(nn.Module):
prefix
=
f
"
{
prefix
}
.layers"
,
prefix
=
f
"
{
prefix
}
.layers"
,
)
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
539aa992
...
@@ -22,7 +22,6 @@
...
@@ -22,7 +22,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from
array
import
array
from
functools
import
lru_cache
,
partial
from
functools
import
lru_cache
,
partial
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
TypedDict
,
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
TypedDict
,
Union
)
Union
)
...
@@ -46,7 +45,7 @@ from vllm.attention import AttentionMetadata
...
@@ -46,7 +45,7 @@ from vllm.attention import AttentionMetadata
from
vllm.attention.selector
import
(
_Backend
,
backend_name_to_enum
,
from
vllm.attention.selector
import
(
_Backend
,
backend_name_to_enum
,
get_global_forced_attn_backend
)
get_global_forced_attn_backend
)
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
get_pp_group
,
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -66,9 +65,12 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
...
@@ -66,9 +65,12 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
from
vllm.multimodal.base
import
MultiModalData
from
vllm.multimodal.base
import
MultiModalData
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
SequenceData
)
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.transformers_utils.processor
import
get_processor
from
vllm.utils
import
is_cpu
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -207,7 +209,7 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -207,7 +209,7 @@ class Qwen2VisionAttention(nn.Module):
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
if
selected_backend
is
None
:
if
selected_backend
is
None
:
# For Volta and Turing GPUs, use xformers instead.
# For Volta and Turing GPUs, use xformers instead.
device_available
=
current_platform
.
get
_device_capability
(
)[
0
]
>=
8
device_available
=
current_platform
.
has
_device_capability
(
80
)
if
device_available
:
if
device_available
:
from
transformers.utils
import
is_flash_attn_2_available
from
transformers.utils
import
is_flash_attn_2_available
...
@@ -280,6 +282,21 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -280,6 +282,21 @@ class Qwen2VisionAttention(nn.Module):
context_layer
=
rearrange
(
output
,
context_layer
=
rearrange
(
output
,
"(b s) ... -> b s ..."
,
"(b s) ... -> b s ..."
,
b
=
batch_size
)
b
=
batch_size
)
elif
is_cpu
():
seq_length
=
q
.
size
(
1
)
q
,
k
,
v
=
[
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
]]
attention_mask
=
torch
.
zeros
([
1
,
seq_length
,
seq_length
],
device
=
q
.
device
,
dtype
=
torch
.
bool
)
for
i
in
range
(
1
,
len
(
cu_seqlens
)):
attention_mask
[...,
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
],
cu_seqlens
[
i
-
1
]:
cu_seqlens
[
i
]]
=
True
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attention_mask
,
dropout_p
=
0.0
)
context_layer
=
rearrange
(
output
,
"b h s d -> b s h d "
)
else
:
else
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
...
@@ -681,15 +698,14 @@ def dummy_data_for_qwen2_vl(
...
@@ -681,15 +698,14 @@ def dummy_data_for_qwen2_vl(
"--limit-mm-per-prompt."
)
"--limit-mm-per-prompt."
)
hf_config
=
ctx
.
get_hf_config
(
Qwen2VLConfig
)
hf_config
=
ctx
.
get_hf_config
(
Qwen2VLConfig
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
hf_config
.
vision_start_token_id
])
dummy_seqdata
=
SequenceData
.
from_token_counts
(
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
hf_config
.
vision_start_token_id
,
1
),
[
hf_config
.
image_token_id
])
*
max_llm_image_tokens
(
hf_config
.
image_token_id
,
max_llm_image_tokens
),
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
hf_config
.
vision_end_token_id
,
1
),
[
hf_config
.
vision_end_token_id
])
(
0
,
seq_len
-
max_llm_image_tokens
-
2
),
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
)
[
0
])
*
(
seq_len
-
max_llm_image_tokens
-
2
)
dummy_seqdata
=
SequenceData
(
token_ids
)
dummy_image
=
Image
.
new
(
"RGB"
,
(
max_resized_width
,
max_resized_height
),
dummy_image
=
Image
.
new
(
"RGB"
,
(
max_resized_width
,
max_resized_height
),
color
=
0
)
color
=
0
)
...
@@ -859,15 +875,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -859,15 +875,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
model
=
Qwen2Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
Qwen2Model
(
config
,
cache_config
,
quant_config
)
if
config
.
tie_word_embeddings
:
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
self
.
model
.
embed_tokens
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
else
:
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
PPMissingLayer
()
config
.
hidden_size
,
quant_config
=
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
_validate_and_reshape_mm_tensor
(
self
,
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
Union
[
torch
.
Tensor
,
mm_input
:
Union
[
torch
.
Tensor
,
...
@@ -982,7 +1004,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -982,7 +1004,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
if
image_input
is
None
and
video_input
is
None
:
if
(
image_input
is
None
and
video_input
is
None
)
or
not
get_pp_group
().
is_first_rank
:
inputs_embeds
=
None
inputs_embeds
=
None
else
:
else
:
if
getattr
(
self
.
config
,
"rope_scaling"
,
{}).
get
(
"type"
,
if
getattr
(
self
.
config
,
"rope_scaling"
,
{}).
get
(
"type"
,
...
@@ -1018,6 +1041,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1018,6 +1041,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
return
hidden_states
return
hidden_states
...
@@ -1058,6 +1082,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1058,6 +1082,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
@@ -1084,6 +1110,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1084,6 +1110,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
except
KeyError
:
except
KeyError
:
print
(
params_dict
.
keys
())
print
(
params_dict
.
keys
())
...
...
vllm/model_executor/models/siglip.py
View file @
539aa992
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
within a vision language model."""
within a vision language model."""
import
math
import
math
from
array
import
array
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
...
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.sequence
import
SequenceData
try
:
try
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
@@ -67,11 +67,10 @@ def dummy_seq_data_for_siglip(
...
@@ -67,11 +67,10 @@ def dummy_seq_data_for_siglip(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
return
SequenceData
.
from_token_counts
(
[
image_token_id
])
*
image_feature_size
(
image_token_id
,
image_feature_size
*
num_images
),
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
0
,
seq_len
-
image_feature_size
*
num_images
),
[
0
])
*
(
seq_len
-
image_feature_size
)
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_siglip
(
def
dummy_image_for_siglip
(
...
@@ -91,6 +90,24 @@ def dummy_image_for_siglip(
...
@@ -91,6 +90,24 @@ def dummy_image_for_siglip(
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def dummy_video_for_siglip(
    hf_config: SiglipVisionConfig,
    num_frames: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Build dummy video multimodal data from one repeated dummy frame.

    A single dummy image is requested from ``dummy_image_for_siglip``,
    converted to a NumPy array and repeated ``num_frames`` times along a
    new leading axis.

    Args:
        hf_config: SigLIP vision config forwarded to
            ``dummy_image_for_siglip``.
        num_frames: Number of copies of the frame to stack.
        image_width_override: Forwarded unchanged to
            ``dummy_image_for_siglip``.
        image_height_override: Forwarded unchanged to
            ``dummy_image_for_siglip``.

    Returns:
        A dict with the single key ``"video"`` holding the stacked frames.
    """
    frame_data = dummy_image_for_siglip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override,
    )
    single_frame = np.array(frame_data["image"])
    # Wrapping the frame in a list adds the leading axis that np.repeat
    # then extends to ``num_frames`` entries.
    stacked_frames = np.repeat([single_frame], num_frames, axis=0)
    return {"video": stacked_frames}
def
input_processor_for_siglip
(
def
input_processor_for_siglip
(
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
hf_config
:
SiglipVisionConfig
,
hf_config
:
SiglipVisionConfig
,
...
@@ -503,6 +520,7 @@ class SiglipVisionModel(nn.Module):
...
@@ -503,6 +520,7 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
num_heads
=
config
.
num_attention_heads
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
...
@@ -513,10 +531,6 @@ class SiglipVisionModel(nn.Module):
...
@@ -513,10 +531,6 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override
=
num_hidden_layers_override
,
num_hidden_layers_override
=
num_hidden_layers_override
,
)
)
@
property
def
_require_post_layernorm
(
self
)
->
bool
:
return
self
.
vision_model
.
post_layernorm
is
not
None
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
vision_model
.
embeddings
.
patch_embedding
return
self
.
vision_model
.
embeddings
.
patch_embedding
...
@@ -542,12 +556,12 @@ class SiglipVisionModel(nn.Module):
...
@@ -542,12 +556,12 @@ class SiglipVisionModel(nn.Module):
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# post_layernorm is optional in SiglipVisionModel
# post_layernorm is optional in SiglipVisionModel
if
(
"vision_model.post_layernorm"
in
name
if
(
name
.
startswith
(
"vision_model.post_layernorm"
)
and
not
self
.
_require_
post_layernorm
):
and
self
.
vision_model
.
post_layernorm
is
None
):
continue
continue
# omit layers when num_hidden_layers_override is set
# omit layers when num_hidden_layers_override is set
if
"vision_model.encoder.layers
."
in
name
:
if
name
.
startswith
(
"vision_model.encoder.layers
"
)
:
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
if
layer_idx
>=
layer_count
:
if
layer_idx
>=
layer_count
:
continue
continue
...
...
vllm/model_executor/models/solar.py
0 → 100644
View file @
539aa992
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
get_compressed_tensors_cache_scale
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.interfaces
import
SupportsLoRA
from
vllm.model_executor.models.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
class SolarMLP(nn.Module):
    """Feed-forward block used by each Solar decoder layer.

    The input passes through a merged gate/up column-parallel projection,
    the ``SiluAndMul`` activation, and a row-parallel down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """Create the two projections and the activation.

        Args:
            hidden_size: Input width of the gate/up projection and output
                width of the down projection.
            intermediate_size: Width of each of the two merged gate/up
                outputs and input width of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Passed through to both linear layers.
            bias: Passed through to both linear layers.
            prefix: Name prefix passed to the linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Gate and up projections share one merged weight with two
        # equally sized outputs.
        merged_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=merged_widths,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection."""
        # Both linear layers return a pair; only the first item is used.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
class SolarAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    attention, and output projection."""

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        """Set up head bookkeeping and the attention sub-modules.

        Args:
            config: Model config; only an optional ``head_dim`` attribute
                is read here.
            hidden_size: Model hidden width.
            num_heads: Total number of query heads (before tensor-parallel
                partitioning).
            num_kv_heads: Total number of key/value heads.
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Scaling dict passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to
                ``get_rope``.
            quant_config: Passed to the linear layers and ``Attention``.
            bias: Whether the QKV and output projections use a bias.
            cache_config: Passed to ``Attention``.
            prefix: Name prefix for the linear layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < tp_size:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % num_kv_heads == 0
        else:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        default_head_dim = hidden_size // num_heads
        self.head_dim = getattr(config, "head_dim", default_head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected
        output."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected
class SolarDecoderLayer(nn.Module):
    """One Solar transformer layer: RMSNorm, self-attention, RMSNorm, MLP.

    The residual stream is carried alongside the hidden states and is
    passed to the RMSNorm calls rather than added explicitly here.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and normalisation sub-modules.

        Args:
            config: Model config supplying sizes, rope settings, bias
                flags, activation name and RMSNorm epsilon.
            cache_config: Passed to the attention block.
            quant_config: Passed to the attention and MLP blocks.
            prefix: Name prefix for the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            original_max_pos = getattr(
                config, "original_max_position_embeddings", None)
            if original_max_pos:
                # Note: this writes into the dict object obtained from
                # the config.
                rope_scaling["original_max_position_embeddings"] = (
                    original_max_pos)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = (getattr(config, "attention_bias", False)
                          or getattr(config, "bias", False))

        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)
        self.self_attn = SolarAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = SolarMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become
        the residual; otherwise the residual is handed to the input norm
        together with the hidden states.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class SolarModel(nn.Module):
    """Solar decoder stack with pipeline-parallel aware construction.

    Embeddings, decoder layers and the final norm are only instantiated
    on the pipeline ranks that need them; other ranks get a
    ``PPMissingLayer`` placeholder.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build embeddings, decoder layers and the final norm.

        Args:
            config: Model config (sizes, layer count, ``bskcn_*`` fields).
            cache_config: Passed to every decoder layer.
            quant_config: Passed to every decoder layer.
            lora_config: When given, enlarges the vocabulary by
                ``lora_extra_vocab_size * (max_loras or 1)``.
            prefix: Name prefix for the decoder layers.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        # The embedding lives on the first rank, and also on the last
        # rank when the word embeddings are tied.
        needs_embedding = get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank)
        if needs_embedding:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def build_layer(prefix: str):
            return SolarDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` if
        given, else the embedding of ``input_ids``. Other ranks read
        ``hidden_states`` and ``residual`` from ``intermediate_tensors``.

        Returns:
            The normalised hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        cfg = self.config
        # Snapshots taken at the bskcn_1 / bskcn_2 layers.
        # NOTE(review): snapshots are local to this call, so a blend layer
        # whose snapshot layer ran on an earlier pipeline rank would see
        # None here - confirm how pipeline splits interact with bskcn_*.
        saved_h1 = saved_r1 = None
        saved_h2 = saved_r2 = None

        # bskcn_tv holds two weights: index 0 for training, 1 otherwise.
        blend = cfg.bskcn_tv[0 if self.training else 1]

        for layer_idx in range(self.start_layer, self.end_layer):
            if layer_idx in cfg.bskcn_1:
                saved_h1 = hidden_states.clone()
                saved_r1 = residual.clone()
            if layer_idx in cfg.bskcn_2:
                saved_h2 = hidden_states.clone()
                saved_r2 = residual.clone()
            if layer_idx in cfg.bskcn_3:
                # Blend the first snapshot back into the current state.
                hidden_states = saved_h1 * blend + hidden_states * (1 -
                                                                    blend)
                residual = saved_r1 * blend + residual * (1 - blend)
            if layer_idx in cfg.bskcn_4:
                # Blend the second snapshot back into the current state.
                hidden_states = saved_h2 * blend + hidden_states * (1 -
                                                                    blend)
                residual = saved_r2 * blend + residual * (1 - blend)

            decoder_layer = self.layers[layer_idx]
            # kv_caches is indexed relative to this rank's first layer.
            local_idx = layer_idx - self.start_layer
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[local_idx],
                attn_metadata,
                residual,
            )

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
class SolarForCausalLM(nn.Module, SupportsLoRA):
    """Solar causal language model: decoder stack plus LM head.

    Wraps ``SolarModel`` and, on the last pipeline rank, adds the LM
    head, logits processor and sampler. Also provides checkpoint weight
    loading and KV-cache scale loading.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder stack and, on the last rank, the output head.

        Args:
            config: Model config.
            cache_config: Passed to ``SolarModel``.
            quant_config: Passed to ``SolarModel`` and the LM head.
            lora_config: When given, extends the vocabulary size and
                selects the LoRA vocab padding size for the LM head.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = SolarModel(
            config,
            cache_config,
            quant_config,
            lora_config=lora_config,
            prefix="model",
        )
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                # Share the embedding weight with the LM head.
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``SolarModel`` and return its output."""
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from hidden states via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Return zero-filled ``hidden_states`` and ``residual`` tensors of
        shape ``(batch_size, hidden_size)``."""
        shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors({
            "hidden_states":
            torch.zeros(shape, dtype=dtype, device=device),
            "residual":
            torch.zeros(shape, dtype=dtype, device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with a shard id.
        Rotary-embedding caches are skipped, as are bias tensors with no
        matching parameter and parameters reported missing on this rank
        by ``is_pp_missing_parameter``.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Only the first element of the stored scale is loaded.
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # No fused-parameter mapping consumed this entry; load it
                # as a regular parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV-cache scaling factors from a file.

        Args:
            quantization_param_path: Path handed to
                ``kv_cache_scales_loader``.

        Raises:
            RuntimeError: If a layer's attention module does not expose
                the ``_kv_scale`` attribute this method assigns.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path,
                tp_rank,
                tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type,
        ):
            layer = self.model.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # Presumably a placeholder for a layer held by another
                # pipeline rank (confirm); there is no attention module
                # to update, so skip it rather than reuse a stale one.
                continue
            attn_layer = layer.self_attn.attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2

            # Check the attribute that is actually assigned below. The
            # previous check looked for "kv_scale" on the SolarAttention
            # wrapper, which never defines it, so loading always failed.
            # NOTE(review): confirm the attribute name against the
            # Attention layer of the vLLM version in use.
            if hasattr(attn_layer, "_kv_scale"):
                attn_layer._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")
vllm/model_executor/models/ultravox.py
View file @
539aa992
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
"""PyTorch Ultravox model."""
import
itertools
import
math
import
math
from
array
import
array
from
array
import
array
from
functools
import
lru_cache
from
functools
import
lru_cache
...
@@ -21,15 +20,16 @@ from vllm.config import CacheConfig, MultiModalConfig
...
@@ -21,15 +20,16 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs.data
import
LLMInputs
from
vllm.inputs.data
import
LLMInputs
from
vllm.inputs.registry
import
InputContext
from
vllm.inputs.registry
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.utils
import
(
filter_weights
,
flatten_bn
,
from
vllm.model_executor.models.utils
import
(
flatten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -43,8 +43,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
...
@@ -43,8 +43,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN
=
128002
_AUDIO_PLACEHOLDER_TOKEN
=
128002
_AUDIO_TOKENS_PER_SECOND
=
6.25
_AUDIO_TOKENS_PER_SECOND
=
6.25
logger
=
init_logger
(
__name__
)
class
UltravoxAudioFeatureInputs
(
TypedDict
):
class
UltravoxAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
type
:
Literal
[
"audio_features"
]
...
@@ -77,15 +75,11 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
...
@@ -77,15 +75,11 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
return
math
.
ceil
(
feature_extractor
.
chunk_length
*
_AUDIO_TOKENS_PER_SECOND
)
return
math
.
ceil
(
feature_extractor
.
chunk_length
*
_AUDIO_TOKENS_PER_SECOND
)
def
dummy_data_for_ultravox
(
def
dummy_
seq_
data_for_ultravox
(
ctx
:
InputContext
,
ctx
:
InputContext
,
seq_len
:
int
,
seq_len
:
int
,
mm
_count
s
:
Mapping
[
str
,
int
]
,
audio
_count
:
int
,
):
):
feature_extractor
=
whisper_feature_extractor
(
ctx
)
audio_count
=
mm_counts
[
"audio"
]
audio_placeholder
=
array
(
audio_placeholder
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
_AUDIO_PLACEHOLDER_TOKEN
])
*
get_ultravox_max_audio_tokens
(
ctx
)
[
_AUDIO_PLACEHOLDER_TOKEN
])
*
get_ultravox_max_audio_tokens
(
ctx
)
...
@@ -96,10 +90,28 @@ def dummy_data_for_ultravox(
...
@@ -96,10 +90,28 @@ def dummy_data_for_ultravox(
other_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
other_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
len
(
audio_token_ids
))
[
0
])
*
(
seq_len
-
len
(
audio_token_ids
))
return
SequenceData
(
audio_token_ids
+
other_token_ids
)
def
dummy_audio_for_ultravox
(
ctx
:
InputContext
,
audio_count
:
int
,
):
feature_extractor
=
whisper_feature_extractor
(
ctx
)
audio_and_sr
=
(
np
.
array
([
0.0
]
*
feature_extractor
.
chunk_length
),
1
)
audio_and_sr
=
(
np
.
array
([
0.0
]
*
feature_extractor
.
chunk_length
),
1
)
mm_dict
=
{
"audio"
:
[
audio_and_sr
]
*
audio_count
}
return
{
"audio"
:
[
audio_and_sr
]
*
audio_count
}
def
dummy_data_for_ultravox
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
):
audio_count
=
mm_counts
[
"audio"
]
seq_data
=
dummy_seq_data_for_ultravox
(
ctx
,
seq_len
,
audio_count
)
mm_dict
=
dummy_audio_for_ultravox
(
ctx
,
audio_count
)
return
(
S
eq
uenceData
(
audio_token_ids
+
other_token_ids
)
,
mm_dict
)
return
(
s
eq
_data
,
mm_dict
)
def
input_mapper_for_ultravox
(
ctx
:
InputContext
,
data
:
object
):
def
input_mapper_for_ultravox
(
ctx
:
InputContext
,
data
:
object
):
...
@@ -323,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -323,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
self
.
multi_modal_config
=
multimodal_config
self
.
multi_modal_config
=
multimodal_config
assert
self
.
multi_modal_config
assert
self
.
multi_modal_config
self
.
secondary_weights
=
[]
self
.
audio_tower
=
ModifiedWhisperEncoder
(
config
.
audio_config
)
if
config
.
audio_model_id
is
not
None
:
if
config
.
audio_model_id
is
not
None
:
self
.
audio_tower
=
ModifiedWhisperEncoder
.
from_pretrained
(
self
.
secondary_weights
.
append
(
config
.
audio_model_id
)
DefaultModelLoader
.
Source
(
else
:
model_or_path
=
config
.
audio_model_id
,
self
.
audio_tower
=
ModifiedWhisperEncoder
(
config
.
audio_config
)
revision
=
None
,
prefix
=
"audio_tower."
,
))
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
)
if
config
.
text_model_id
is
not
None
:
self
.
secondary_weights
.
append
(
DefaultModelLoader
.
Source
(
model_or_path
=
config
.
text_model_id
,
revision
=
None
,
prefix
=
"language_model."
))
def
_audio_features_to_embeddings
(
def
_audio_features_to_embeddings
(
self
,
input_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
,
input_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -453,11 +474,22 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -453,11 +474,22 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
# prepare weight iterators for components
projector_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
2
)
weights_group
=
group_weights_with_prefix
(
weights
)
# load audio tower weights
audio_tower_weights
=
weights_group
[
"audio_tower"
]
audio_tower_params_dict
=
dict
(
self
.
audio_tower
.
named_parameters
(
prefix
=
self
.
audio_tower
.
base_model_prefix
))
for
name
,
loaded_weight
in
audio_tower_weights
:
if
name
in
audio_tower_params_dict
:
param
=
audio_tower_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load projector weights
# load projector weights
projector_weights
=
filter_weights
(
projector_weights
,
projector_weights
=
weights_group
[
"multi_modal_projector"
]
"multi_modal_projector"
)
projector_params_dict
=
dict
(
projector_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
projector_weights
:
for
name
,
loaded_weight
in
projector_weights
:
...
@@ -467,5 +499,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -467,5 +499,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
self
.
language_model
.
load_weights
(
llm_weights
)
vllm/model_executor/models/utils.py
View file @
539aa992
import
itertools
from
collections
import
UserDict
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
Union
,
overload
)
...
@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
...
@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
def
filter_weights
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
):
class WeightsGroup(UserDict):
    """
    Wraps grouped weights dictionary for a more informative error message
    when attempting to access a weight component that does not exist.

    Lookups behave like a plain dictionary; only the ``KeyError`` raised for
    a missing prefix is replaced by one that lists the available prefixes.
    """

    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
        """Return the weights stored under ``key``.

        Raises:
            KeyError: If ``key`` is not a known prefix. The message names the
                requested prefix and the set of prefixes that do exist; the
                original ``KeyError`` is chained as ``__cause__``.
        """
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no weights named with the prefix: {key}. "
                   f"Available prefix: {set(self.keys())}")
            raise KeyError(msg) from exc
def
filter_weights
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
)
->
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]:
"""
"""
Helper function to load weights for inner vLLM models.
Helper function to load weights for inner vLLM models.
...
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
...
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
yield
name
,
loaded_weight
yield
name
,
loaded_weight
def group_weights_with_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
    """
    Helper function to group weights with prefix

    The prefix of a weight is the part of its name before the first ".".
    The input is traversed once to collect the distinct prefixes, then one
    independent copy of the stream per prefix is passed to ``filter_weights``
    (defined elsewhere in this file; expected to yield only the entries
    under that prefix). The result is wrapped in ``WeightsGroup``.
    """
    # First copy is consumed to discover prefixes; second copy is replayed.
    scan_iter, replay_iter = itertools.tee(weights, 2)
    prefixes = set()
    for name, _ in scan_iter:
        prefixes.add(name.split(".")[0])

    # One independent iterator over the full stream for every prefix.
    per_prefix_iters = itertools.tee(replay_iter, len(prefixes))

    grouped = {}
    for weights_iter, prefix in zip(per_prefix_iters, prefixes):
        grouped[prefix] = filter_weights(weights_iter, prefix)
    return WeightsGroup(grouped)
def
init_vllm_registered_model
(
def
init_vllm_registered_model
(
hf_config
:
PretrainedConfig
,
hf_config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
],
cache_config
:
Optional
[
CacheConfig
],
...
...
vllm/model_executor/parameter.py
View file @
539aa992
...
@@ -328,6 +328,64 @@ class PackedvLLMParameter(ModelWeightParameter):
...
@@ -328,6 +328,64 @@ class PackedvLLMParameter(ModelWeightParameter):
marlin_tile_size
=
self
.
marlin_tile_size
)
marlin_tile_size
=
self
.
marlin_tile_size
)
def permute_param_layout_(param: "BasevLLMParameter", input_dim: int,
                          output_dim: int, **kwargs) -> "BasevLLMParameter":
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if
    required, asserting if it cannot get it to the correct layout)

    Args:
        param: Parameter to permute in place. Its current layout is read from
            the ``input_dim`` / ``output_dim`` attributes when present.
        input_dim: Index the input dimension should have after the call.
        output_dim: Index the output dimension should have after the call.
        **kwargs: Only ``packed_dim`` is read: the index the packed
            dimension is expected to have after the call. Repacking is not
            supported, so this is only checked, never enforced.

    Returns:
        The same ``param`` object, with ``param.data`` permuted and the
        ``_input_dim`` / ``_output_dim`` / ``_packed_dim`` attributes updated
        when they exist.

    Raises:
        AssertionError: If the current layout cannot be determined, or the
            requested ``packed_dim`` does not match the permuted layout.
    """
    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other
    # we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # the current input dim at input_dim and the current output dim at
    # output_dim, preserving the relative order of the other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    # Insert in ascending target-index order: inserting the larger target
    # first would be shifted by the later insert at a smaller index (e.g. a
    # 3D param with input_dim=1, output_dim=0 would end with input at 2).
    placements = sorted(
        [(input_dim, curr_input_dim), (output_dim, curr_output_dim)],
        key=lambda placement: placement[0])
    for target_dim, current_dim in placements:
        perm.insert(target_dim, current_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param
def
_adjust_shard_indexes_for_marlin
(
shard_size
,
shard_offset
,
def
_adjust_shard_indexes_for_marlin
(
shard_size
,
shard_offset
,
marlin_tile_size
):
marlin_tile_size
):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
...
...
vllm/model_executor/sampling_metadata.py
View file @
539aa992
import
random
from
array
import
array
from
array
import
array
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
@@ -8,15 +7,10 @@ import torch
...
@@ -8,15 +7,10 @@ import torch
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
is_pin_memory_available
,
make_tensor_with_pad
)
maybe_expand_dim
)
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER
=
False
@
dataclass
@
dataclass
...
@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
...
@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
generator
=
None
,
generator
=
None
,
is_prompt
=
True
,
is_prompt
=
True
,
prompt_logprob_indices
=
[],
prompt_logprob_indices
=
[],
sample_indices
=
[])
sample_indices
=
[],
)
class
SamplingMetadataCache
:
class
SamplingMetadataCache
:
"""Used to cache SamplingMetadata objects between scheduler iterations
"""Used to cache SamplingMetadata objects between scheduler iterations"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_seq_group_to_sample_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
self
.
_seq_group_to_sample_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
...
@@ -124,12 +118,12 @@ class SamplingMetadata:
...
@@ -124,12 +118,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit),
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
it is mainly used for multi-step decode.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -165,16 +159,19 @@ class SamplingMetadata:
...
@@ -165,16 +159,19 @@ class SamplingMetadata:
num_prompts
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
,
generators
,
cache
)
device
,
generators
,
cache
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
selected_token_indices
=
async_tensor_h2d
(
dtype
=
torch
.
long
,
selected_token_indices
,
target_device
=
device
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
)
target_device
=
device
,
pin_memory
=
pin_memory
,
)
categorized_sample_indices
=
{
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
t
:
async_tensor_h2d
(
async_tensor_h2d
(
seq_ids
,
seq_ids
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
target_device
=
device
,
target_device
=
device
,
pin_memory
=
pin_memory
),
2
,
2
)
pin_memory
=
pin_memory
,
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
}
...
@@ -201,8 +198,8 @@ def _prepare_seq_groups(
...
@@ -201,8 +198,8 @@ def _prepare_seq_groups(
device
:
str
,
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
List
[
int
]]
,
int
,
]:
"""Prepare sequence groups and indices for sampling.
"""Prepare sequence groups and indices for sampling.
Args:
Args:
...
@@ -233,16 +230,13 @@ def _prepare_seq_groups(
...
@@ -233,16 +230,13 @@ def _prepare_seq_groups(
# Sampling type -> (
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
# indices to sample within pruned logits)
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]
]]
=
{
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
t
:
[]
for
t
in
SamplingType
for
t
in
SamplingType
}
}
# Index of logits to compute logprob. Logits include both prompt logprob
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
# and sample logprob indices.
logit_idx
=
0
logit_idx
=
0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx
=
0
# Total number of prompts from given sequence groups.
# Total number of prompts from given sequence groups.
num_prompts
=
0
num_prompts
=
0
...
@@ -264,10 +258,10 @@ def _prepare_seq_groups(
...
@@ -264,10 +258,10 @@ def _prepare_seq_groups(
# If the current seq group is in decode stage, it is None.
# If the current seq group is in decode stage, it is None.
seq_len
:
Optional
[
int
]
=
None
seq_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
prompt_logprob_indices
:
List
[
int
]
=
\
prompt_logprob_indices
:
List
[
int
]
=
(
sample_obj
.
prompt_logprob_indices
sample_obj
.
prompt_logprob_indices
if
cache
is
not
None
else
[]
if
cache
is
not
None
else
[]
)
sample_indices
:
List
[
int
]
=
\
sample_indices
:
List
[
int
]
=
(
sample_obj
.
sample_indices
sample_obj
.
sample_indices
if
cache
is
not
None
else
[]
if
cache
is
not
None
else
[]
)
do_sample
=
seq_group_metadata
.
do_sample
do_sample
=
seq_group_metadata
.
do_sample
if
seq_group_metadata
.
is_prompt
:
if
seq_group_metadata
.
is_prompt
:
...
@@ -333,11 +327,8 @@ def _prepare_seq_groups(
...
@@ -333,11 +327,8 @@ def _prepare_seq_groups(
if
do_sample
:
if
do_sample
:
sample_indices
.
extend
(
range
(
logit_idx
,
logit_idx
+
sample_len
))
sample_indices
.
extend
(
range
(
logit_idx
,
logit_idx
+
sample_len
))
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
list
(
list
(
range
(
logit_idx
,
logit_idx
+
sample_len
)))
zip
(
range
(
logit_idx
,
logit_idx
+
sample_len
),
range
(
sample_idx
,
sample_idx
+
sample_len
))))
logit_idx
+=
sample_len
logit_idx
+=
sample_len
sample_idx
+=
sample_len
if
cache
is
not
None
:
if
cache
is
not
None
:
sample_obj
.
sampling_params
=
sampling_params
sample_obj
.
sampling_params
=
sampling_params
...
@@ -356,7 +347,8 @@ def _prepare_seq_groups(
...
@@ -356,7 +347,8 @@ def _prepare_seq_groups(
generator
=
generator
,
generator
=
generator
,
is_prompt
=
is_prompt
,
is_prompt
=
is_prompt
,
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
sample_indices
=
list
(
sample_indices
))
sample_indices
=
list
(
sample_indices
),
)
seq_groups
.
append
(
sample_obj
)
seq_groups
.
append
(
sample_obj
)
...
@@ -378,9 +370,6 @@ class SamplingTensors:
...
@@ -378,9 +370,6 @@ class SamplingTensors:
presence_penalties
:
torch
.
Tensor
presence_penalties
:
torch
.
Tensor
frequency_penalties
:
torch
.
Tensor
frequency_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
sampling_seeds
:
torch
.
Tensor
sample_indices
:
torch
.
Tensor
extra_seeds
:
Optional
[
torch
.
Tensor
]
prompt_tokens
:
torch
.
Tensor
prompt_tokens
:
torch
.
Tensor
output_tokens
:
torch
.
Tensor
output_tokens
:
torch
.
Tensor
...
@@ -391,15 +380,7 @@ class SamplingTensors:
...
@@ -391,15 +380,7 @@ class SamplingTensors:
vocab_size
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
*
,
extra_seeds_to_generate
:
int
=
0
,
extra_entropy
:
Optional
[
Tuple
[
int
,
...]]
=
None
)
->
Tuple
[
"SamplingTensors"
,
bool
,
bool
,
bool
]:
)
->
Tuple
[
"SamplingTensors"
,
bool
,
bool
,
bool
]:
"""
extra_seeds_to_generate: extra seeds to generate using the
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens
:
List
[
array
]
=
[]
prompt_tokens
:
List
[
array
]
=
[]
output_tokens
:
List
[
array
]
=
[]
output_tokens
:
List
[
array
]
=
[]
top_ks
:
List
[
int
]
=
[]
top_ks
:
List
[
int
]
=
[]
...
@@ -409,19 +390,10 @@ class SamplingTensors:
...
@@ -409,19 +390,10 @@ class SamplingTensors:
presence_penalties
:
List
[
float
]
=
[]
presence_penalties
:
List
[
float
]
=
[]
frequency_penalties
:
List
[
float
]
=
[]
frequency_penalties
:
List
[
float
]
=
[]
repetition_penalties
:
List
[
float
]
=
[]
repetition_penalties
:
List
[
float
]
=
[]
sampling_seeds
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
do_penalties
=
False
do_penalties
=
False
do_top_p_top_k
=
False
do_top_p_top_k
=
False
do_min_p
=
False
do_min_p
=
False
if
_USE_TRITON_SAMPLER
:
prompt_best_of
:
List
[
int
]
=
[]
# We need one base seed per Triton slice.
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
assert
sampling_metadata
.
seq_groups
is
not
None
assert
sampling_metadata
.
seq_groups
is
not
None
for
seq_group
in
sampling_metadata
.
seq_groups
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
...
@@ -452,7 +424,7 @@ class SamplingTensors:
...
@@ -452,7 +424,7 @@ class SamplingTensors:
do_penalties
=
True
do_penalties
=
True
is_prompt
=
seq_group
.
is_prompt
is_prompt
=
seq_group
.
is_prompt
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
)
:
if
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
:
# For tokens in the prompt that we only need to get
# For tokens in the prompt that we only need to get
# their logprobs
# their logprobs
query_len
=
seq_group
.
query_len
query_len
=
seq_group
.
query_len
...
@@ -477,28 +449,6 @@ class SamplingTensors:
...
@@ -477,28 +449,6 @@ class SamplingTensors:
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
if
_USE_TRITON_SAMPLER
:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
seed
=
sampling_params
.
seed
is_greedy
=
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
extra_entropy
=
extra_entropy
or
()
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
seq_data
.
get_len
(),
*
extra_entropy
,
seq_id
,
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
extend
(
seq_group
.
sample_indices
)
if
do_penalties
:
if
do_penalties
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
...
@@ -518,23 +468,37 @@ class SamplingTensors:
...
@@ -518,23 +468,37 @@ class SamplingTensors:
output_tokens
.
append
(
seq_data
.
output_token_ids_array
)
output_tokens
.
append
(
seq_data
.
output_token_ids_array
)
sampling_tensors
=
SamplingTensors
.
from_lists
(
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
temperatures
,
frequency_penalties
,
repetition_penalties
,
sampling_seeds
,
top_ps
,
sample_indices
,
prompt_tokens
,
output_tokens
,
vocab_size
,
top_ks
,
extra_seeds_to_generate
,
device
,
dtype
)
min_ps
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
,
prompt_tokens
,
output_tokens
,
vocab_size
,
device
,
dtype
,
)
return
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
return
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
@
classmethod
@
classmethod
def
from_lists
(
cls
,
temperatures
:
List
[
float
],
top_ps
:
List
[
float
],
def
from_lists
(
top_ks
:
List
[
int
],
min_ps
:
List
[
float
],
cls
,
presence_penalties
:
List
[
float
],
temperatures
:
List
[
float
],
frequency_penalties
:
List
[
float
],
top_ps
:
List
[
float
],
repetition_penalties
:
List
[
float
],
top_ks
:
List
[
int
],
sampling_seeds
:
List
[
int
],
sample_indices
:
List
[
int
],
min_ps
:
List
[
float
],
prompt_tokens
:
List
[
array
],
output_tokens
:
List
[
array
],
presence_penalties
:
List
[
float
],
vocab_size
:
int
,
extra_seeds_to_generate
:
int
,
frequency_penalties
:
List
[
float
],
device
:
torch
.
device
,
repetition_penalties
:
List
[
float
],
dtype
:
torch
.
dtype
)
->
"SamplingTensors"
:
prompt_tokens
:
List
[
array
],
output_tokens
:
List
[
array
],
vocab_size
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
)
->
"SamplingTensors"
:
# Note that the performance will be very bad without
# Note that the performance will be very bad without
# pinned memory.
# pinned memory.
pin_memory
=
is_pin_memory_available
()
pin_memory
=
is_pin_memory_available
()
...
@@ -603,34 +567,9 @@ class SamplingTensors:
...
@@ -603,34 +567,9 @@ class SamplingTensors:
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
)
)
sample_indices_t
=
torch
.
tensor
(
sample_indices
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
sampling_seeds_t
=
torch
.
tensor
(
sampling_seeds
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
).
t
().
contiguous
()
# Because the memory is pinned, we can do non-blocking
# Because the memory is pinned, we can do non-blocking
# transfer to device.
# transfer to device.
# How many seeds the sample operation itself will need.
num_base_seeds
=
sampling_seeds_t
.
shape
[
0
]
-
extra_seeds_to_generate
sampling_seeds_gpu
=
sampling_seeds_t
.
to
(
device
=
device
,
non_blocking
=
True
)
extra_seeds_gpu
=
sampling_seeds_gpu
[
num_base_seeds
:]
if
not
extra_seeds_gpu
.
numel
():
extra_seeds_gpu
=
None
sampling_seeds_gpu
=
sampling_seeds_gpu
[:
num_base_seeds
]
return
cls
(
return
cls
(
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
...
@@ -644,38 +583,4 @@ class SamplingTensors:
...
@@ -644,38 +583,4 @@ class SamplingTensors:
non_blocking
=
True
),
non_blocking
=
True
),
prompt_tokens
=
prompt_t
.
to
(
device
=
device
,
non_blocking
=
True
),
prompt_tokens
=
prompt_t
.
to
(
device
=
device
,
non_blocking
=
True
),
output_tokens
=
output_t
.
to
(
device
=
device
,
non_blocking
=
True
),
output_tokens
=
output_t
.
to
(
device
=
device
,
non_blocking
=
True
),
sampling_seeds
=
sampling_seeds_gpu
,
sample_indices
=
sample_indices_t
.
to
(
device
=
device
,
non_blocking
=
True
),
extra_seeds
=
extra_seeds_gpu
,
)
)
@staticmethod
def _get_sequence_seeds(
    seed: int,
    *extra_entropy: int,
    seeds_to_generate: int,
    is_greedy: bool,
):
    """Get `seeds_to_generate` child seeds from `seed` and extra entropy.

    Greedy requests get all-zero seeds. Otherwise seeds are drawn from the
    global ``random`` module when ``seed`` is None, or from a private
    ``random.Random`` seeded with the string form of
    ``(seed, *extra_entropy)``.
    """
    if is_greedy:
        # For the kernel, seed == 0 means greedy decoding.
        return [0] * seeds_to_generate

    if seed is None:
        draw = random.randint
    else:
        draw = random.Random(str((seed, ) + extra_entropy)).randint

    int64_info = torch.iinfo(torch.long)
    lo, hi = int64_info.min, int64_info.max

    # If the user/random sets seed = 0 but request should
    # have sampling, we need to change it to something
    # else. We use a constant in that case.
    # This way we don't need to create and load a bool
    # matrix in the sampling kernel, which reduces CPU
    # overhead and latency.
    seq_seeds = []
    for _ in range(seeds_to_generate):
        drawn = draw(lo, hi)
        seq_seeds.append(drawn if drawn else _SEED_0_REPLACEMENT)
    return seq_seeds
vllm/model_executor/utils.py
View file @
539aa992
"""Utils for model executor."""
"""Utils for model executor."""
import
random
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
numpy
as
np
import
torch
import
torch
from
vllm.utils
import
seed_everything
def
set_random_seed
(
seed
:
int
)
->
None
:
def
set_random_seed
(
seed
:
int
)
->
None
:
random
.
seed
(
seed
)
seed_everything
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
seed
)
def
set_weight_attrs
(
def
set_weight_attrs
(
...
...
vllm/multimodal/base.py
View file @
539aa992
...
@@ -14,7 +14,8 @@ from typing_extensions import TypeAlias
...
@@ -14,7 +14,8 @@ from typing_extensions import TypeAlias
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
JSONTree
,
is_list_of
,
json_map_leaves
from
vllm.utils
import
(
JSONTree
,
get_allowed_kwarg_only_overrides
,
is_list_of
,
json_map_leaves
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -53,6 +54,12 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -53,6 +54,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if
isinstance
(
nested_tensors
,
torch
.
Tensor
):
if
isinstance
(
nested_tensors
,
torch
.
Tensor
):
return
nested_tensors
return
nested_tensors
if
isinstance
(
nested_tensors
,
np
.
ndarray
):
return
torch
.
from_numpy
(
nested_tensors
)
if
isinstance
(
nested_tensors
,
(
int
,
float
)):
return
torch
.
tensor
(
nested_tensors
)
stacked
=
[
MultiModalInputs
.
_try_stack
(
t
)
for
t
in
nested_tensors
]
stacked
=
[
MultiModalInputs
.
_try_stack
(
t
)
for
t
in
nested_tensors
]
if
not
is_list_of
(
stacked
,
torch
.
Tensor
,
check
=
"all"
):
if
not
is_list_of
(
stacked
,
torch
.
Tensor
,
check
=
"all"
):
# Only tensors (not lists) can be stacked.
# Only tensors (not lists) can be stacked.
...
@@ -256,11 +263,20 @@ class MultiModalPlugin(ABC):
...
@@ -256,11 +263,20 @@ class MultiModalPlugin(ABC):
model_cls
,
_
=
get_model_architecture
(
model_config
)
model_cls
,
_
=
get_model_architecture
(
model_config
)
mapper
=
self
.
_input_mappers
.
get
(
model_cls
)
mapper
=
self
.
_input_mappers
.
get
(
model_cls
)
# Only get processor kwargs at mapping time if we are not using the
# input mapper; no overrides are used on the default here because they
# should be passed to the huggingface resource at initialization time.
if
mapper
is
not
None
and
mapper
!=
self
.
_default_input_mapper
:
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
mapper
,
overrides
=
model_config
.
mm_processor_kwargs
)
else
:
mm_processor_kwargs
=
{}
if
mapper
is
None
:
if
mapper
is
None
:
raise
KeyError
(
f
"No input mapper in
{
self
}
is registered for "
raise
KeyError
(
f
"No input mapper in
{
self
}
is registered for "
f
"model class
{
model_cls
.
__name__
}
."
)
f
"model class
{
model_cls
.
__name__
}
."
)
return
mapper
(
InputContext
(
model_config
),
data
)
return
mapper
(
InputContext
(
model_config
),
data
,
**
mm_processor_kwargs
)
@
abstractmethod
@
abstractmethod
def
_default_max_multimodal_tokens
(
self
,
ctx
:
InputContext
)
->
int
:
def
_default_max_multimodal_tokens
(
self
,
ctx
:
InputContext
)
->
int
:
...
@@ -333,7 +349,10 @@ class MultiModalPlugin(ABC):
...
@@ -333,7 +349,10 @@ class MultiModalPlugin(ABC):
f
"for model class
{
model_cls
.
__name__
}
in
{
self
}
."
)
f
"for model class
{
model_cls
.
__name__
}
in
{
self
}
."
)
if
callable
(
max_mm_tokens
):
if
callable
(
max_mm_tokens
):
max_mm_tokens
=
max_mm_tokens
(
InputContext
(
model_config
))
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
max_mm_tokens
,
overrides
=
model_config
.
mm_processor_kwargs
)
max_mm_tokens
=
max_mm_tokens
(
InputContext
(
model_config
),
**
mm_processor_kwargs
)
self
.
_validate_max_multimodal_tokens
(
max_mm_tokens
)
self
.
_validate_max_multimodal_tokens
(
max_mm_tokens
)
...
...
Prev
1
…
13
14
15
16
17
18
19
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment