Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
465
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
959 additions
and
622 deletions
+959
-622
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+8
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+21
-20
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+20
-24
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+134
-71
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+61
-22
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+188
-160
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+12
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+173
-126
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+198
-130
vllm/model_executor/models/medusa.py
vllm/model_executor/models/medusa.py
+30
-5
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+5
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+46
-35
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+8
-3
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+7
-2
vllm/model_executor/models/mlp_speculator.py
vllm/model_executor/models/mlp_speculator.py
+12
-2
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+5
-2
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+8
-5
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+5
-2
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+11
-4
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+7
-2
No files found.
Too many changes to show.
To preserve performance only
465 of 465+
files are displayed.
Plain diff
Email patch
vllm/model_executor/models/gpt_neox.py
View file @
af7f4372
...
...
@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def
__init__
(
self
,
config
,
config
:
GPTNeoXConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
...
...
@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config
.
hidden_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
embed_out
.
weight
=
self
.
gpt_neox
.
embed_in
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -258,8 +260,11 @@ class GPTNeoXForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
embed_out
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/interfaces.py
View file @
af7f4372
from
typing
import
(
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
from
typing_extensions
import
Type
Guard
from
typing_extensions
import
Type
Is
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
...
...
@@ -10,12 +10,12 @@ logger = init_logger(__name__)
@
runtime_checkable
class
Supports
Vision
(
Protocol
):
"""The interface required for all
vision language models (VLMs)
."""
class
Supports
MultiModal
(
Protocol
):
"""The interface required for all
multi-modal models
."""
supports_
vision
:
ClassVar
[
Literal
[
True
]]
=
True
supports_
multimodal
:
ClassVar
[
Literal
[
True
]]
=
True
"""
A flag that indicates this model supports
vision
inputs.
A flag that indicates this model supports
multi-modal
inputs.
Note:
There is no need to redefine this flag if this class is in the
...
...
@@ -29,30 +29,31 @@ class SupportsVision(Protocol):
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@
runtime_checkable
class
_Supports
Vision
Type
(
Protocol
):
supports_
vision
:
Literal
[
True
]
class
_Supports
MultiModal
Type
(
Protocol
):
supports_
multimodal
:
Literal
[
True
]
def
__call__
(
self
,
*
,
multimodal_config
:
MultiModalConfig
)
->
None
:
...
@
overload
def
supports_vision
(
model
:
Type
[
object
])
->
TypeGuard
[
Type
[
SupportsVision
]]:
def
supports_multimodal
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
SupportsMultiModal
]]:
...
@
overload
def
supports_
vision
(
model
:
object
)
->
Type
Guard
[
Supports
Vision
]:
def
supports_
multimodal
(
model
:
object
)
->
Type
Is
[
Supports
MultiModal
]:
...
def
supports_
vision
(
def
supports_
multimodal
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
Type
Guard
[
Type
[
Supports
Vision
]],
Type
Guard
[
Supports
Vision
]]:
)
->
Union
[
Type
Is
[
Type
[
Supports
MultiModal
]],
Type
Is
[
Supports
MultiModal
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_Supports
Vision
Type
)
return
isinstance
(
model
,
_Supports
MultiModal
Type
)
return
isinstance
(
model
,
Supports
Vision
)
return
isinstance
(
model
,
Supports
MultiModal
)
@
runtime_checkable
...
...
@@ -94,18 +95,18 @@ class _SupportsLoRAType(Protocol):
@
overload
def
supports_lora
(
model
:
Type
[
object
])
->
Type
Guard
[
Type
[
SupportsLoRA
]]:
def
supports_lora
(
model
:
Type
[
object
])
->
Type
Is
[
Type
[
SupportsLoRA
]]:
...
@
overload
def
supports_lora
(
model
:
object
)
->
Type
Guard
[
SupportsLoRA
]:
def
supports_lora
(
model
:
object
)
->
Type
Is
[
SupportsLoRA
]:
...
def
supports_lora
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
Type
Guard
[
Type
[
SupportsLoRA
]],
Type
Guard
[
SupportsLoRA
]]:
)
->
Union
[
Type
Is
[
Type
[
SupportsLoRA
]],
Type
Is
[
SupportsLoRA
]]:
result
=
_supports_lora
(
model
)
if
not
result
:
...
...
@@ -137,7 +138,7 @@ def supports_lora(
def
_supports_lora
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
Type
Guard
[
Type
[
SupportsLoRA
]],
Type
Guard
[
SupportsLoRA
]]:
)
->
Union
[
Type
Is
[
Type
[
SupportsLoRA
]],
Type
Is
[
SupportsLoRA
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_SupportsLoRAType
)
...
...
@@ -172,18 +173,18 @@ class _HasInnerStateType(Protocol):
@
overload
def
has_inner_state
(
model
:
object
)
->
Type
Guard
[
HasInnerState
]:
def
has_inner_state
(
model
:
object
)
->
Type
Is
[
HasInnerState
]:
...
@
overload
def
has_inner_state
(
model
:
Type
[
object
])
->
Type
Guard
[
Type
[
HasInnerState
]]:
def
has_inner_state
(
model
:
Type
[
object
])
->
Type
Is
[
Type
[
HasInnerState
]]:
...
def
has_inner_state
(
model
:
Union
[
Type
[
object
],
object
]
)
->
Union
[
Type
Guard
[
Type
[
HasInnerState
]],
Type
Guard
[
HasInnerState
]]:
)
->
Union
[
Type
Is
[
Type
[
HasInnerState
]],
Type
Is
[
HasInnerState
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_HasInnerStateType
)
...
...
vllm/model_executor/models/internlm2.py
View file @
af7f4372
...
...
@@ -87,6 +87,7 @@ class InternLM2Attention(nn.Module):
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
key_value_groups
=
int
(
self
.
num_heads
/
self
.
num_kv_heads
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
...
...
@@ -120,6 +121,14 @@ class InternLM2Attention(nn.Module):
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
):
qkv
=
qkv
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
key_value_groups
+
2
,
128
)
q
,
k
,
v
=
torch
.
split
(
qkv
,
[
self
.
key_value_groups
,
1
,
1
],
dim
=
2
)
q
=
q
.
reshape
(
-
1
,
self
.
q_size
)
k
=
k
.
reshape
(
-
1
,
self
.
kv_size
)
v
=
v
.
reshape
(
-
1
,
self
.
kv_size
)
return
q
,
k
,
v
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -128,7 +137,7 @@ class InternLM2Attention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
wqkv
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
self
.
split_qkv
(
qkv
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
wo
(
attn_output
)
...
...
@@ -264,6 +273,8 @@ class InternLM2ForCausalLM(nn.Module):
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -279,8 +290,11 @@ class InternLM2ForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
output
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
@@ -319,24 +333,6 @@ class InternLM2ForCausalLM(nn.Module):
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
if
"wqkv"
in
name
:
config
=
self
.
config
kv_groups
=
(
config
.
num_attention_heads
//
config
.
num_key_value_heads
)
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
loaded_weight
=
loaded_weight
.
view
(
-
1
,
2
+
kv_groups
,
head_dim
,
loaded_weight
.
shape
[
-
1
])
wq
,
wk
,
wv
=
torch
.
split
(
loaded_weight
,
[
kv_groups
,
1
,
1
],
dim
=
1
)
wq
=
wq
.
reshape
(
-
1
,
wq
.
shape
[
-
1
])
wk
=
wk
.
reshape
(
-
1
,
wk
.
shape
[
-
1
])
wv
=
wv
.
reshape
(
-
1
,
wv
.
shape
[
-
1
])
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
wq
,
'q'
)
weight_loader
(
param
,
wk
,
'k'
)
weight_loader
(
param
,
wv
,
'v'
)
else
:
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/internvl.py
View file @
af7f4372
...
...
@@ -5,7 +5,8 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import
itertools
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
...
...
@@ -18,18 +19,18 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.intern_vit
import
InternVisionModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.
image
import
cached_get_tokenizer
from
vllm.multimodal.
utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_num_patches
)
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
from
.interfaces
import
SupportsMultiModal
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
IMG_START
=
'<img>'
IMG_END
=
'</img>'
...
...
@@ -38,9 +39,6 @@ IMG_CONTEXT = '<IMG_CONTEXT>'
IMAGENET_MEAN
=
(
0.485
,
0.456
,
0.406
)
IMAGENET_STD
=
(
0.229
,
0.224
,
0.225
)
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
3000
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
500
class
InternVLImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
...
...
@@ -53,6 +51,19 @@ class InternVLImagePixelInputs(TypedDict):
"""
class
InternVLImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternVLImageInputs
=
Union
[
InternVLImagePixelInputs
,
InternVLImageEmbeddingInputs
]
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def
build_transform
(
input_size
):
MEAN
,
STD
=
IMAGENET_MEAN
,
IMAGENET_STD
...
...
@@ -84,11 +95,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
return
best_ratio
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
=
1
,
max_num
=
6
,
image_size
=
448
):
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
)
->
Tuple
[
int
,
int
,
int
]:
aspect_ratio
=
orig_width
/
orig_height
# calculate the existing image aspect ratio
...
...
@@ -110,11 +119,9 @@ def calculate_num_blocks(orig_width: int,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
dynamic_preprocess
(
image
,
min_num
=
1
,
max_num
=
6
,
image_size
=
448
,
use_thumbnail
=
False
):
def
dynamic_preprocess
(
image
:
Image
.
Image
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
,
use_thumbnail
:
int
)
->
List
[
Image
.
Image
]:
orig_width
,
orig_height
=
image
.
size
blocks
,
target_width
,
target_height
=
calculate_num_blocks
(
...
...
@@ -138,12 +145,14 @@ def dynamic_preprocess(image,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
image_to_pixel_values
(
image
:
Image
.
Image
,
input_size
=
448
,
max_num
=
6
):
def
image_to_pixel_values
(
image
:
Image
.
Image
,
input_size
:
int
,
min_num
:
int
,
max_num
:
int
,
use_thumbnail
:
bool
)
->
torch
.
Tensor
:
transform
=
build_transform
(
input_size
=
input_size
)
images
=
dynamic_preprocess
(
image
,
min_num
=
min_num
,
max_num
=
max_num
,
image_size
=
input_size
,
use_thumbnail
=
True
,
max_num
=
max_num
)
use_thumbnail
=
use_thumbnail
)
pixel_values
=
[
transform
(
image
)
for
image
in
images
]
pixel_values
=
torch
.
stack
(
pixel_values
)
return
pixel_values
...
...
@@ -157,14 +166,20 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
use_thumbnail
=
hf_config
.
use_thumbnail
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
if
use_thumbnail
:
max_dynamic_patch
+=
1
downsample_ratio
=
hf_config
.
downsample_ratio
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
return
num_patches
*
7
return
num_patches
*
max_dynamic_patch
def
input_processor_for_internvl
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
...
...
@@ -173,24 +188,32 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
width
,
height
=
image_data
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
)
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
max_num
,
image_size
)
# add thumbnail image if num_blocks > 1
if
hf_config
.
use_thumbnail
and
num_blocks
>
1
:
num_blocks
+=
1
image_feature_size
=
num_blocks
*
num_patches
elif
isinstance
(
image_data
,
torch
.
Tensor
):
raise
NotImplementedError
(
"Embeddings input is not supported yet"
)
image_feature_size
=
image_data
.
shape
[
0
]
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -198,8 +221,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
(
num_blocks
+
1
)
*
num_patches
+
IMG_END
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
image_feature_size
+
IMG_END
new_prompt
=
prompt
.
replace
(
'<image>'
,
image_prompt
,
1
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
...
...
@@ -209,8 +231,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
):
hf_config
=
ctx
.
get_hf_config
()
use_thumbnail
=
hf_config
.
use_thumbnail
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
image_size
=
hf_config
.
vision_config
.
image_size
if
isinstance
(
data
,
Image
.
Image
):
data
=
image_to_pixel_values
(
data
)
data
=
image_to_pixel_values
(
data
,
image_size
,
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -224,11 +257,13 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
})
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -236,14 +271,23 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
add_special_tokens
=
False
)[
0
],
image_feature_size_override
=
image_feature_size
,
)
image_size
=
vision_config
.
image_size
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
max_image_width
=
max_num
*
image_size
max_image_height
=
min_num
*
image_size
mm_data
=
dummy_image_for_clip
(
vision_config
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
num_images
,
image_width_override
=
max_image_width
,
image_height_override
=
max_image_height
,
)
return
seq_data
,
mm_data
...
...
@@ -253,7 +297,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_internvl_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_internvl
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_internvl
)
class
InternVLChatModel
(
nn
.
Module
,
Supports
Vision
):
class
InternVLChatModel
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
...
...
@@ -283,10 +327,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
self
.
vision_model
=
InternVisionModel
(
config
.
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
)
llm_class
=
ModelRegistry
.
load_model_cls
(
config
.
text_config
.
architectures
[
0
])
self
.
language_model
=
llm_class
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
vit_hidden_size
=
config
.
vision_config
.
hidden_size
llm_hidden_size
=
config
.
text_config
.
hidden_size
...
...
@@ -356,15 +398,26 @@ class InternVLChatModel(nn.Module, SupportsVision):
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
InternVLImage
Pixel
Inputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
InternVLImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_token_id
=
kwargs
.
pop
(
"image_token_id"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
:
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
InternVLImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
)
self
.
img_context_token_id
=
image_token_id
[
0
]
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
...
...
@@ -374,6 +427,21 @@ class InternVLChatModel(nn.Module, SupportsVision):
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_process_image_input
(
self
,
image_input
:
InternVLImageInputs
,
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
assert
self
.
vision_model
is
not
None
image_embeds
=
self
.
extract_feature
(
image_input
[
"data"
])
return
image_embeds
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -387,9 +455,9 @@ class InternVLChatModel(nn.Module, SupportsVision):
if
image_input
is
not
None
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
vi
t
_embeds
=
self
.
extract_feature
(
image_input
[
"data"
]
)
inputs_embeds
=
merge_
vision
_embeddings
(
input_ids
,
inputs_embeds
,
vit
_embeds
,
vi
sion
_embed
ding
s
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision
_embed
ding
s
,
self
.
img_context_token_id
)
input_ids
=
None
else
:
...
...
@@ -403,8 +471,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
...
...
@@ -415,24 +486,16 @@ class InternVLChatModel(nn.Module, SupportsVision):
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
_filter_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
):
for
name
,
loaded_weight
in
weights
:
name
=
name
.
split
(
"."
)
if
prefix
==
name
.
pop
(
0
):
name
=
"."
.
join
(
name
)
yield
name
,
loaded_weight
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
vit_weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
# load vision encoder
vit_weights
=
self
.
_
filter_weights
(
vit_weights
,
"vision_model"
)
vit_weights
=
filter_weights
(
vit_weights
,
"vision_model"
)
self
.
vision_model
.
load_weights
(
vit_weights
)
# load mlp projector
mlp_weights
=
self
.
_
filter_weights
(
mlp_weights
,
"mlp1"
)
mlp_weights
=
filter_weights
(
mlp_weights
,
"mlp1"
)
mlp_params_dict
=
dict
(
self
.
mlp1
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_weights
:
param
=
mlp_params_dict
[
name
]
...
...
@@ -441,5 +504,5 @@ class InternVLChatModel(nn.Module, SupportsVision):
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
self
.
_
filter_weights
(
llm_weights
,
"language_model"
)
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
vllm/model_executor/models/jais.py
View file @
af7f4372
...
...
@@ -20,14 +20,14 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -37,12 +37,14 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.transformers_utils.configs
import
JAISConfig
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
SwiGLUActivation
(
nn
.
Module
):
...
...
@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
config
:
JAISConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
self
.
embeddings_scale
=
config
.
embeddings_scale
else
:
self
.
embeddings_scale
=
config
.
mup_embeddings_scale
self
.
h
=
nn
.
ModuleList
([
JAISBlock
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
JAISBlock
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.h"
,
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
...
...
@@ -243,7 +251,9 @@ class JAISModel(nn.Module):
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
if
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
wte
(
input_ids
)
if
self
.
wpe
is
not
None
:
position_embeds
=
self
.
wpe
(
position_ids
)
...
...
@@ -252,10 +262,18 @@ class JAISModel(nn.Module):
hidden_states
=
inputs_embeds
hidden_states
*=
torch
.
tensor
(
float
(
self
.
embeddings_scale
),
dtype
=
hidden_states
.
dtype
)
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
len
(
self
.
h
)
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -273,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
JAISModel
(
config
,
cache_config
,
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
if
hasattr
(
config
,
"width_scale"
):
self
.
output_logits_scale
=
config
.
width_scale
else
:
...
...
@@ -290,17 +312,30 @@ class JAISLMHeadModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
torch
.
Tensor
,
...
...
@@ -324,6 +359,10 @@ class JAISLMHeadModel(nn.Module):
continue
if
not
name
.
startswith
(
"transformer."
):
name
=
"transformer."
+
name
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
...
...
vllm/model_executor/models/jamba.py
View file @
af7f4372
...
...
@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
return
hidden_states
class
JambaMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
JambaConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
hidden_size
=
config
.
hidden_size
intermediate_size
=
config
.
intermediate_size
hidden_act
=
config
.
hidden_act
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
JambaMoE
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
return
hidden_states
.
view
(
orig_shape
)
class
JambaMLP
(
JambaMoE
):
def
__init__
(
self
,
config
:
JambaConfig
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
(
config
,
num_experts
=
1
,
top_k
=
1
,
params_dtype
=
params_dtype
,
tp_size
=
tp_size
,
quant_config
=
quant_config
)
class
JambaMambaDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -609,12 +592,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
# Current step used indices
self
.
current_indices
:
List
[
int
]
=
[]
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Used as an input_buffer for the CUDA graph runs.
self
.
mamba_gc_cache_buffer
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
...
...
@@ -644,95 +623,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
batch_size
=
input_ids
.
shape
[
0
]
if
attn_metadata
.
prefill_metadata
:
batch_size
=
len
(
request_ids_to_seq_ids
)
(
current_seqlen_agnostic_cache
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
mamba_cache
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
else
:
# CUDA graph capturing runs
current_seqlen_agnostic_cache
,
indices
=
(
kwargs
[
"seqlen_agnostic_capture_inputs"
],
[],
)
self
.
current_indices
=
indices
mamba_cache
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
current_seqlen_agnostic_cache
[
0
],
current_seqlen_agnostic_cache
[
1
])
if
"seqlen_agnostic_capture_inputs"
not
in
kwargs
:
self
.
_copy_mamba_cache_by_indices
(
self
.
current_indices
,
current_seqlen_agnostic_cache
)
attn_metadata
,
mamba_cache
[
0
],
mamba_cache
[
1
])
return
hidden_states
def
_
copy
_mamba_cache
_by_indices
(
self
,
indices
:
List
[
int
],
current_seqlen_agnostic_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
])
:
for
i
,
offset
in
enumerate
(
indices
):
self
.
_copy_mamba_cache
(
offset
,
i
,
current_seqlen_agnostic_cache
)
def
_
swap
_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
[
to_index
,
from_index
]]
=
\
cache_t
[:,
[
from_index
,
to_index
]]
def
_copy_mamba_cache
(
self
,
index_to
:
int
,
index_from
:
int
,
from_buffer
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]):
def
_copy_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
(
cache_t
,
from_buffer_t
)
in
zip
(
self
.
mamba_cache
,
from_buffer
)
:
cache_t
[:,
index
_to
].
copy_
(
from_buffer_t
[:,
index_from
],
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
to_
index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
def
_assign_seq_id_to_mamba_cache
(
self
,
cur_rid
:
str
,
seqs_id
:
List
[
int
])
->
List
[
int
]:
indices_for_current_run
=
[]
for
seq_id
in
seqs_id
:
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{}
def
_move_out_if_already_occupied
(
self
,
index
:
int
,
all_occupied_indices
:
List
[
int
]):
if
index
in
all_occupied_indices
:
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
first_free_index
index_for_current_run
=
first_free_index
## case of decoding n>1, copy prefill cache to decoding indices
# In case occupied, move the occupied to a new empty block
self
.
_move_cache_index_and_mappings
(
from_index
=
index
,
to_index
=
first_free_index
)
def
_assign_seq_id_to_mamba_cache_in_specific_dest
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
destination_index
:
int
):
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{
seq_id
:
destination_index
}
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
index_exist
=
list
(
seq_ids2indices
.
values
())[
0
]
self
.
_copy_mamba_cache
(
index_from
=
index_exist
,
index_to
=
first_free_index
,
from_buffer
=
self
.
mamba_cache
)
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
index_exists
=
list
(
seq_ids2indices
.
values
())[
0
]
# case of decoding n>1, copy prefill cache to decoding indices
self
.
_copy_mamba_cache
(
from_index
=
index_exists
,
to_index
=
destination_index
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
first_free_index
index_for_current_run
=
first_free_index
seq_id
]
=
destination_index
else
:
index_for_current_run
=
self
.
mamba_cache_indices_mapping
[
# already exists
cache_index_already_exists
=
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
indices_for_current_run
.
append
(
index_for_current_run
)
return
indices_for_current_run
if
cache_index_already_exists
!=
destination_index
:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self
.
_swap_pair_indices_and_mappings
(
from_index
=
cache_index_already_exists
,
to_index
=
destination_index
)
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
List
[
int
]]:
indices_for_current_run
=
[]
for
request_id
,
seqs_id
in
request_ids_to_seq_ids
.
items
():
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]):
running_indices
=
[]
request_ids_to_seq_ids_flatten
=
[
(
req_id
,
seq_id
)
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
seq_id
in
seq_ids
]
for
dest_index
,
(
request_id
,
seq_id
)
in
enumerate
(
request_ids_to_seq_ids_flatten
):
if
request_id
in
finished_requests_ids
:
# Do not allocate cache for requests that run
# Do not allocate cache
index
for requests that run
# and finish right after
continue
indices_for_current_run
+=
self
.
_assign_seq_id_to_mamba_cache
(
request_id
,
seqs_id
)
## Pad the batch in case of running batch that was not captured via CG
padded_indices
=
indices_for_current_run
.
copy
()
pad_index
=
self
.
_first_free_index_in_mamba_cache
()
self
.
_assign_seq_id_to_mamba_cache_in_specific_dest
(
request_id
,
seq_id
,
dest_index
)
running_indices
.
append
(
dest_index
)
self
.
_clean_up_first_bs_blocks
(
batch_size
,
running_indices
)
conv_state
=
self
.
mamba_cache
[
0
][:,
:
batch_size
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
:
batch_size
]
for
_
in
range
(
batch_size
-
len
(
indices_for_current_run
)):
padded_indices
.
append
(
pad_index
)
return
(
conv_state
,
temporal_state
)
conv_state
=
self
.
mamba_cache
[
0
][:,
padded_indices
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
padded_indices
]
def
_get_all_occupied_indices
(
self
):
return
[
cache_idx
for
seq_ids2indices
in
self
.
mamba_cache_indices_mapping
.
values
()
for
cache_idx
in
seq_ids2indices
.
values
()
]
return
(
conv_state
,
temporal_state
),
indices_for_current_run
def
_clean_up_first_bs_blocks
(
self
,
batch_size
:
int
,
indices_for_current_run
:
List
[
int
]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices
=
set
([
range
(
batch_size
)])
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
for
destination_index
in
destination_indices
:
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
destination_index
not
in
indices_for_current_run
:
# move not running indices outside of the batch
all_other_indices
=
list
(
range
(
batch_size
,
max_possible_batch_size
))
first_avail_index
=
self
.
_first_free_index_in_mamba_cache
(
all_other_indices
)
self
.
_swap_indices
(
from_index
=
destination_index
,
to_index
=
first_avail_index
)
def
_move_cache_index_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_copy_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_update_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_pair_indices_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_swap_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_swap_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
elif
to_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
from_index
})
def
_update_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
return
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
...
...
@@ -747,28 +779,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
self
.
_release_mamba_cache
(
finished_requests_ids
)
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
(
current_mamba_cache
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
cg_batch_size
,
finished_requests_ids
)
self
.
current_indices
=
indices
for
input_buffer
,
current_cache_buffer
in
zip
(
input_buffers
[
"seqlen_agnostic_capture_inputs"
],
current_mamba_cache
):
input_buffer
.
copy_
(
current_cache_buffer
,
non_blocking
=
True
)
def
copy_outputs_after_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
Copy the relevant Mamba cache from the CUDA graph input_buffers
back to the JambaForCausalLM.mamba_cache after CUDA
graph replay run is done.
"""
self
.
_copy_mamba_cache_by_indices
(
self
.
current_indices
,
input_buffers
[
"seqlen_agnostic_capture_inputs"
])
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
...
...
@@ -776,26 +789,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_gc_cache_buffer
)
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_cache
)
def
_release_mamba_cache
(
self
,
finished_seq_groups_req_ids
:
List
[
str
]):
for
req_id
in
finished_seq_groups_req_ids
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
def
_first_free_index_in_mamba_cache
(
self
)
->
int
:
if
self
.
mamba_cache
:
def
_first_free_index_in_mamba_cache
(
self
,
indices_range
:
Optional
[
List
[
int
]]
=
None
)
->
int
:
assert
self
.
mamba_cache
is
not
None
if
indices_range
is
None
:
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
occupied
=
[
id
for
seq_ids
in
self
.
mamba_cache_indices_mapping
.
values
()
for
id
in
seq_ids
.
values
()
]
first_free_index
=
[
i
not
in
occupied
for
i
in
range
(
max_possible_batch_size
)
].
index
(
True
)
return
first_free_index
return
0
indices_range
=
list
(
range
(
max_possible_batch_size
))
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
for
i
in
indices_range
:
if
i
not
in
all_occupied_indices
:
return
i
raise
Exception
(
"Couldn't find a free spot in the mamba cache! This"
"should never happen"
)
def
_get_mamba_cache_shape
(
self
...
...
@@ -819,12 +831,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
max_batch_size
=
(
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
)
+
10
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
for
buffername
in
[
"mamba_cache"
,
"mamba_gc_cache_buffer"
]:
buffer
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
self
.
mamba_cache
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
conv_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
),
...
...
@@ -832,10 +843,12 @@ class JambaForCausalLM(nn.Module, HasInnerState):
temporal_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
))
setattr
(
self
,
buffername
,
buffer
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
@@ -854,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
...
...
@@ -877,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
if
".self_attn."
in
name
:
name
=
name
.
replace
(
".self_attn"
,
""
)
if
"feed_forward"
in
name
and
not
_is_moe_layer
(
name
):
## map MLP layers to expert with ID=0
name
=
name
.
replace
(
"feed_forward"
,
"feed_forward.experts.0"
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
...
...
@@ -891,10 +906,15 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
for
(
param_name
,
weight_name
,
expert_id
,
shard_id
,
)
in
expert_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
...
...
@@ -913,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
_is_moe_layer
(
name
:
str
):
return
any
(
[
experts_name
in
name
for
experts_name
in
[
"experts"
,
"router"
,
]])
vllm/model_executor/models/llama.py
View file @
af7f4372
...
...
@@ -145,6 +145,7 @@ class LlamaAttention(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
output_size
=
hidden_size
,
...
...
@@ -153,12 +154,17 @@ class LlamaAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
is_neox_style
=
True
if
quant_config
is
not
None
and
quant_config
.
get_name
()
==
"gguf"
:
is_neox_style
=
False
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
is_neox_style
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
...
...
@@ -291,6 +297,7 @@ class LlamaModel(nn.Module):
self
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
quant_config
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
...
...
@@ -444,8 +451,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/llava.py
View file @
af7f4372
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
from
transformers
import
CLIPVisionConfig
,
LlavaConfig
from
transformers
import
CLIPVisionConfig
,
LlavaConfig
,
SiglipVisionConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_max_clip_image_tokens
,
input_processor_for_clip
)
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_max_clip_image_tokens
,
input_processor_for_clip
)
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
,
input_processor_for_siglip
)
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
class
LlavaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
class
LlavaImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaImageInputs
=
Union
[
LlavaImagePixelInputs
,
LlavaImageEmbeddingInputs
]
# TODO(xwjiang): Run benchmark and decide if TP.
...
...
@@ -53,38 +67,56 @@ class LlavaMultiModalProjector(nn.Module):
return
hidden_states
class
LlavaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
LlavaImageInputs
=
LlavaImagePixelInputs
def
get_max_llava_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
get_max_clip_image_tokens
(
vision_config
)
num_image_tokens
=
get_max_clip_image_tokens
(
vision_config
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
num_image_tokens
=
get_max_siglip_image_tokens
(
vision_config
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
strategy
=
hf_config
.
vision_feature_select_strategy
if
strategy
==
"default"
:
return
num_image_tokens
-
1
elif
strategy
==
"full"
:
return
num_image_tokens
else
:
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def
dummy_data_for_llava
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_llava
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_llava_image_tokens
(
ctx
)
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_clip
(
vision_config
)
mm_data
=
dummy_image_for_clip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
seq_data
=
dummy_seq_data_for_siglip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_siglip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
...
...
@@ -100,12 +132,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
image_feature_size
=
get_max_llava_image_tokens
(
ctx
)
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
input_processor_for_clip
(
model_config
,
vision_config
,
llm_inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
input_processor_for_siglip
(
model_config
,
vision_config
,
llm_inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
def
_init_vision_tower
(
hf_config
:
LlavaConfig
):
vision_config
=
hf_config
.
vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer
=
hf_config
.
vision_feature_layer
if
vision_feature_layer
<
0
:
num_hidden_layers
=
hf_config
.
vision_config
.
num_hidden_layers
\
+
vision_feature_layer
+
1
else
:
num_hidden_layers
=
vision_feature_layer
+
1
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
CLIPVisionModel
(
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
SiglipVisionModel
(
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
...
...
@@ -116,7 +185,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava
)
class
LlavaForConditionalGeneration
(
nn
.
Module
,
Supports
Vision
):
class
LlavaForConditionalGeneration
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
LlavaConfig
,
...
...
@@ -128,36 +197,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer
=
config
.
vision_feature_layer
if
vision_feature_layer
<
0
:
num_hidden_layers
=
config
.
vision_config
.
num_hidden_layers
\
+
vision_feature_layer
+
1
else
:
num_hidden_layers
=
vision_feature_layer
+
1
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_tower
=
CLIPVisionModel
(
config
.
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
)
self
.
vision_tower
=
_init_vision_tower
(
config
)
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
quant_config
=
quant_config
self
.
language_model
=
LlamaModel
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
text_config
.
hidden_size
,
org_num_embeddings
=
self
.
language_model
.
org_vocab_size
,
quant_config
=
quant_config
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
text_config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
...
...
@@ -175,19 +223,31 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
:
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
LlavaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
LlavaImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_select_image_features
(
self
,
image_features
:
torch
.
Tensor
,
*
,
strategy
:
str
)
->
torch
.
Tensor
:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
...
...
@@ -198,8 +258,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
CLIPVisionModel
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_image_pixels_to_features
(
self
,
vision_tower
:
Union
[
CLIPVisionModel
,
SiglipVisionModel
],
pixel_values
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
...
...
@@ -220,6 +283,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def
_process_image_input
(
self
,
image_input
:
LlavaImageInputs
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
assert
self
.
vision_tower
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
return
self
.
multi_modal_projector
(
image_features
)
...
...
@@ -272,9 +339,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_
vision
_embeddings
(
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
...
...
@@ -282,7 +350,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
...
...
@@ -291,59 +359,38 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# only doing this for language model part for now.
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
# post_layernorm is not needed in CLIPVisionModel
if
"vision_model.post_layernorm"
in
name
:
continue
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
self
.
vision_tower
is
not
None
:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading
=
True
else
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
and
name
in
params_dict
:
param
=
params_dict
[
name
]
# prepare weight iterators for components
vit_weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
# load vision encoder
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_weights
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
vllm/model_executor/models/llava_next.py
View file @
af7f4372
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
from
PIL
import
Image
from
transformers
import
CLIPVisionConfig
,
LlavaNextConfig
from
transformers
import
CLIPVisionConfig
,
LlavaNextConfig
,
SiglipVisionConfig
from
transformers.models.llava_next.modeling_llava_next
import
(
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
...
...
@@ -12,23 +14,22 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_image_feature_size
,
get_clip_patch_grid_length
,
input_processor_for_clip
)
from
.interfaces
import
Supports
Vision
from
.interfaces
import
Supports
MultiModal
from
.llava
import
LlavaMultiModalProjector
from
.utils
import
merge_vision_embeddings
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
...
...
@@ -59,7 +60,17 @@ class LlavaNextImagePixelInputs(TypedDict):
"""
LlavaNextImageInputs
=
LlavaNextImagePixelInputs
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image input variant tagged ``"image_embeds"`` that carries a tensor."""

    # Discriminator tag for this input variant.
    type: Literal["image_embeds"]

    # Shape: `(batch_size, image_feature_size, hidden_size)`
    # `hidden_size` must match the hidden size of language model backbone.
    data: torch.Tensor
LlavaNextImageInputs
=
Union
[
LlavaNextImagePixelInputs
,
LlavaNextImageEmbeddingInputs
]
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
...
...
@@ -104,7 +115,24 @@ def get_llava_next_image_feature_size(
image_size
=
vision_config
.
image_size
,
patch_size
=
vision_config
.
patch_size
,
)
base_feature_size
=
num_patches
*
num_patches
base_feature_size
=
get_clip_image_feature_size
(
vision_config
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
num_patches
=
get_siglip_patch_grid_length
(
image_size
=
vision_config
.
image_size
,
patch_size
=
vision_config
.
patch_size
,
)
base_feature_size
=
get_siglip_image_feature_size
(
vision_config
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
strategy
=
hf_config
.
vision_feature_select_strategy
if
strategy
==
"default"
:
base_feature_size
-=
1
elif
strategy
==
"full"
:
pass
else
:
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
num_patch_height
,
num_patch_width
=
get_anyres_image_grid_shape
(
image_size
=
(
input_height
,
input_width
),
...
...
@@ -116,18 +144,13 @@ def get_llava_next_image_feature_size(
unpadded_feature_size
,
newline_feature_size
,
)
=
_get_llava_next_num_unpadded_features
(
input_height
,
input_width
,
num_patches
,
num_patch_height
,
num_patches
,
num_patch_height
,
num_patch_width
)
return
unpadded_feature_size
+
newline_feature_size
+
base_feature_size
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
def
get_max_llava_next_image_tokens
(
ctx
:
InputContext
):
return
get_llava_next_image_feature_size
(
ctx
.
get_hf_config
(
LlavaNextConfig
),
input_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
...
...
@@ -135,9 +158,11 @@ def get_max_llava_next_image_tokens(ctx: InputContext):
)
def
dummy_data_for_llava_next
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_llava_next
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_llava_next_image_tokens
(
ctx
)
...
...
@@ -145,12 +170,31 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_clip
(
vision_config
,
num_images
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
)
return
seq_data
,
mm_data
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
seq_data
=
dummy_seq_data_for_siglip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_siglip
(
vision_config
,
num_images
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
)
...
...
@@ -180,7 +224,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
input_width
=
width
,
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
raise
NotImplementedError
(
"Embeddings input is not supported yet"
)
image_feature_size
=
image_data
.
shape
[
0
]
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
...
...
@@ -194,6 +238,40 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
input_processor_for_siglip
(
model_config
,
vision_config
,
llm_inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
def _init_vision_tower(hf_config: LlavaNextConfig):
    """Construct the vision encoder selected by ``hf_config.vision_config``.

    The encoder is built with ``num_hidden_layers_override`` so that it only
    goes up to ``hf_config.vision_feature_layer``. A negative feature layer
    is counted back from ``vision_config.num_hidden_layers``.

    Returns:
        A ``CLIPVisionModel`` for a ``CLIPVisionConfig``, or a
        ``SiglipVisionModel`` for a ``SiglipVisionConfig``.

    Raises:
        NotImplementedError: if the vision config is of any other type.
    """
    vision_config = hf_config.vision_config
    feature_layer = hf_config.vision_feature_layer

    # Number of encoder layers needed to reach the requested feature layer.
    layers_needed = feature_layer + 1
    if feature_layer < 0:
        layers_needed += vision_config.num_hidden_layers

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return tower_cls(vision_config, num_hidden_layers_override=layers_needed)
...
...
@@ -203,7 +281,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_next_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava_next
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava_next
)
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
Supports
Vision
):
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
LlavaNextConfig
,
...
...
@@ -215,36 +293,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer
=
config
.
vision_feature_layer
if
vision_feature_layer
<
0
:
num_hidden_layers
=
config
.
vision_config
.
num_hidden_layers
\
+
vision_feature_layer
+
1
else
:
num_hidden_layers
=
vision_feature_layer
+
1
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_tower
=
CLIPVisionModel
(
config
.
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
)
self
.
vision_tower
=
_init_vision_tower
(
config
)
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
quant_config
=
quant_config
self
.
language_model
=
LlamaModel
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
text_config
.
hidden_size
,
org_num_embeddings
=
self
.
language_model
.
org_vocab_size
,
quant_config
=
quant_config
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
text_config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
...
...
@@ -279,13 +336,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaNextImage
Pixel
Inputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaNextImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_sizes
=
kwargs
.
pop
(
"image_sizes"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
:
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
...
...
@@ -300,6 +359,18 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
image_sizes
=
self
.
_validate_image_sizes
(
image_sizes
),
)
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeds. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
LlavaNextImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_select_image_features
(
self
,
image_features
:
torch
.
Tensor
,
*
,
strategy
:
str
)
->
torch
.
Tensor
:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
...
...
@@ -310,8 +381,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
CLIPVisionModel
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_image_pixels_to_features
(
self
,
vision_tower
:
Union
[
CLIPVisionModel
,
SiglipVisionModel
],
pixel_values
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
...
...
@@ -422,6 +496,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self
,
image_input
:
LlavaNextImageInputs
,
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
[
image_input
[
"data"
]]
patch_embeddings
=
self
.
_process_image_pixels
(
image_input
)
image_sizes
=
image_input
.
get
(
"image_sizes"
)
...
...
@@ -496,9 +574,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_
vision
_embeddings
(
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
...
...
@@ -506,7 +585,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
...
...
@@ -515,59 +594,48 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# only doing this for language model part for now.
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
# post_layernorm is not needed in CLIPVisionModel
if
"vision_model.post_layernorm"
in
name
:
continue
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
self
.
vision_tower
is
not
None
:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading
=
True
else
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
and
name
in
params_dict
:
param
=
params_dict
[
name
]
# prepare weight iterators for components
vit_weights
,
mlp_weights
,
newline_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
4
)
# load vision encoder
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_weights
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load newline
newline_weights
=
filter_weights
(
newline_weights
,
"image_newline"
)
for
name
,
loaded_weight
in
newline_weights
:
assert
name
==
""
param
=
self
.
image_newline
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
vllm/model_executor/models/medusa.py
View file @
af7f4372
...
...
@@ -30,6 +30,19 @@ class ResidualBlock(nn.Module):
class
Medusa
(
nn
.
Module
):
"""This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
Reference implementation: https://github.com/FasterDecoding/Medusa
Differences from reference implementation:
1. Currently this only supports generating proposals from top-1 tokens.
2. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def
__init__
(
self
,
config
:
MedusaConfig
,
**
_
)
->
None
:
super
().
__init__
()
...
...
@@ -57,6 +70,12 @@ class Medusa(nn.Module):
self
.
truncated_vocab_size
,
logit_scale
)
# Token map is an index-to-token mapping that reduces the vocab size for
# the draft model. Using a smaller draft vocab that contains only the
# most frequent tokens reduces the speculation overhead. This doesn't
# affect the acceptance rate much and thus gives more speed-up. By
# default, this is disabled and is only used if the draft checkpoint
# file has a token_map tensor.
self
.
token_map
=
None
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
List
[
torch
.
Tensor
]:
...
...
@@ -65,22 +84,28 @@ class Medusa(nn.Module):
def
compute_logits
(
self
,
hidden_states
:
List
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
)
->
List
[
torch
.
Tensor
]:
logits
=
[]
logits
_lst
:
List
[
torch
.
Tensor
]
=
[]
for
hs
,
lm_head
in
zip
(
hidden_states
,
self
.
lm_heads
):
_logits
=
self
.
logits_processor
(
lm_head
,
hs
,
sampling_metadata
)
if
_logits
is
None
:
# _logits should only be None on rank > 0, in which case
# it should remain true for every lm_head
assert
len
(
logits_lst
)
==
0
continue
if
self
.
token_map
is
None
:
logits
.
append
(
_logits
)
logits
_lst
.
append
(
_logits
)
else
:
logits
.
append
(
-
torch
.
inf
*
torch
.
ones
(
logits
_lst
.
append
(
-
torch
.
inf
*
torch
.
ones
(
size
=
(
*
_logits
.
shape
[:
-
1
],
self
.
orig_vocab_size
),
device
=
_logits
.
device
,
dtype
=
_logits
.
dtype
))
logits
[
-
1
][...,
self
.
token_map
]
=
_logits
logits
_lst
[
-
1
][...,
self
.
token_map
]
=
_logits
return
logits
return
logits
_lst
def
sample
(
self
,
...
...
vllm/model_executor/models/minicpm.py
View file @
af7f4372
...
...
@@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
hidden_states
=
hidden_states
/
self
.
scale_width
if
self
.
config
.
tie_word_embeddings
:
lm_head
=
self
.
model
.
embed_tokens
...
...
vllm/model_executor/models/minicpmv.py
View file @
af7f4372
...
...
@@ -23,9 +23,10 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import
math
import
re
from
array
import
array
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
)
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
torch
...
...
@@ -34,7 +35,7 @@ import torch.types
from
PIL
import
Image
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
transformers
.configuration_utils
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
...
...
@@ -42,21 +43,21 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
Supports
Vision
from
vllm.model_executor.models.interfaces
import
Supports
MultiModal
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.minicpm
import
MiniCPMModel
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
cached_get_tokenizer
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.idefics2_vision_model
import
Idefics2VisionTransformer
...
...
@@ -216,7 +217,6 @@ class BaseResampler(nn.Module):
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
trunc_normal_
(
self
.
query
,
std
=
0.02
)
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
self
.
kv_proj
=
ReplicatedLinear
(
kv_dim
,
embed_dim
,
bias
=
False
)
else
:
...
...
@@ -225,7 +225,6 @@ class BaseResampler(nn.Module):
nn
.
Identity
()(
*
args
,
**
kwargs
),
None
,
)
self
.
attn
=
nn
.
MultiheadAttention
(
embed_dim
,
num_heads
)
self
.
ln_q
=
norm_layer
(
embed_dim
)
self
.
ln_kv
=
norm_layer
(
embed_dim
)
...
...
@@ -261,7 +260,6 @@ class Resampler2(BaseResampler):
norm_layer
)
self
.
adaptive
=
adaptive
pos_embed_arr
=
get_2d_sincos_pos_embed
(
embed_dim
,
grid_size
,
version
=
(
2
,
0
))
...
...
@@ -407,26 +405,28 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def
get_max_minicpmv_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
return
getattr
(
hf_config
,
"query_num"
,
64
)
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
):
token_ids
=
[
0
]
*
seq_len
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]
)
*
seq_len
return
SequenceData
(
token_ids
)
def
dummy_image_for_minicpmv
(
hf_config
:
PretrainedConfig
):
def
dummy_image_for_minicpmv
(
hf_config
:
PretrainedConfig
,
num_images
:
int
):
width
=
height
=
hf_config
.
image_size
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
def
dummy_data_for_minicpmv
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
()
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
)
seq_data
=
dummy_seq_data_for_minicpmv
(
seq_len
,
num_images
)
mm_data
=
dummy_image_for_minicpmv
(
hf_config
,
num_images
)
return
seq_data
,
mm_data
...
...
@@ -482,7 +482,7 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
class
MiniCPMVBaseModel
(
nn
.
Module
,
Supports
Vision
):
class
MiniCPMVBaseModel
(
nn
.
Module
,
Supports
MultiModal
):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated.
...
...
@@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
# All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrates the MiniCPM-V model
# and config class
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
...
...
@@ -633,8 +637,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
)
return
output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
@@ -717,7 +724,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
raise
NotImplementedError
class
MiniCPMV2
(
MiniCPMVBaseModel
):
class
MiniCPMV2
_0
(
MiniCPMVBaseModel
):
def
__init__
(
self
,
...
...
@@ -890,10 +897,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
return
"resampler"
in
name
# NOTE: Currently, information about this model is unavailable. We are
# temporarily using `MiniCPMVQwen2` as it's name. The name may need
# to be modified in the future.
class
MiniCPMVQwen2
(
MiniCPMVBaseModel
):
class
MiniCPMV2_6
(
MiniCPMVBaseModel
):
def
__init__
(
self
,
...
...
@@ -903,6 +907,7 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
config
,
multimodal_config
,
cache_config
,
quant_config
)
assert
self
.
version
==
(
2
,
6
)
def
init_llm
(
self
,
...
...
@@ -930,6 +935,7 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
# The resampler in 2.6 remains consistent with the one in 2.5.
resampler
=
Resampler2_5
(
num_queries
=
self
.
config
.
query_num
,
embed_dim
=
embed_dim
,
...
...
@@ -989,6 +995,13 @@ class MiniCPMVQwen2(MiniCPMVBaseModel):
return
"resampler"
in
name
or
"vpm"
in
name
_SUPPORT_VERSION
=
{
(
2
,
0
):
MiniCPMV2_0
,
(
2
,
5
):
MiniCPMV2_5
,
(
2
,
6
):
MiniCPMV2_6
}
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
...
...
@@ -1016,11 +1029,9 @@ class MiniCPMV(MiniCPMVBaseModel):
version
=
str
(
config
.
version
).
split
(
"."
)
version
=
tuple
([
int
(
x
)
for
x
in
version
])
# Dispatch class based on version
if
version
==
(
2
,
0
):
instance_class
=
MiniCPMV2
elif
version
==
(
2
,
5
):
instance_class
=
MiniCPMV2_5
else
:
instance_class
=
MiniCPMVQwen2
instance_class
=
_SUPPORT_VERSION
.
get
(
version
,
None
)
if
instance_class
is
None
:
raise
ValueError
(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6"
)
return
instance_class
(
config
,
multimodal_config
,
cache_config
,
quant_config
)
vllm/model_executor/models/mixtral.py
View file @
af7f4372
...
...
@@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -375,8 +377,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
@@ -452,7 +457,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_
name
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
af7f4372
...
...
@@ -347,6 +347,8 @@ class MixtralForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -362,8 +364,11 @@ class MixtralForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/mlp_speculator.py
View file @
af7f4372
...
...
@@ -56,6 +56,15 @@ class MLPSpeculatorLayerNorm(nn.Module):
class
MLPSpeculator
(
nn
.
Module
):
"""
An implementation of the speculative models introduced in
"Accelerating Production LLMs with Combined Token/Embedding
Speculators"
https://arxiv.org/pdf/2404.19124
Trained speculators of this type are available on HF hub at:
https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite
"""
def
__init__
(
self
,
config
:
MLPSpeculatorConfig
,
**
kwargs
)
->
None
:
super
().
__init__
()
...
...
@@ -166,13 +175,14 @@ class MLPSpeculator(nn.Module):
states
.
add_
(
z
,
alpha
=
self
.
emb_weight
/
self
.
state_weight
)
states
=
self
.
activation
(
self
.
ln
[
head_index
](
states
))
# b k d
# TODO: not yet supporting top_k_tokens_per_head
previous_hidden_states
=
states
# TODO: not yet supporting top_k_tokens_per_head
states
=
states
.
flatten
(
0
,
1
)
logits
=
self
.
logits_processor
(
self
.
head
[
head_index
],
states
,
sampling_metadata
)
output
=
self
.
sampler
(
logits
.
flatten
(
0
,
1
)
,
sampling_metadata
)
output
=
self
.
sampler
(
logits
,
sampling_metadata
)
last_tokens
=
output
.
sampled_token_ids
next_tokens
.
append
(
output
)
...
...
vllm/model_executor/models/mpt.py
View file @
af7f4372
...
...
@@ -279,8 +279,11 @@ class MPTForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/nemotron.py
View file @
af7f4372
...
...
@@ -53,7 +53,7 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
# - There is no gate_proj, just up_proj
# - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
# - Squared ReLU instead of SwiGLU
# - Adds a
rotary_percent
to RoPE
# - Adds a
partial_rotary_factor
to RoPE
def
_cast_if_autocast_enabled
(
*
args
):
...
...
@@ -161,7 +161,7 @@ class NemotronAttention(nn.Module):
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
rotary_percent
=
config
.
rope_percent
self
.
partial_rotary_factor
=
config
.
partial_rotary_factor
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
...
...
@@ -187,7 +187,7 @@ class NemotronAttention(nn.Module):
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rotary_percent
=
self
.
rotary_percent
,
partial_rotary_factor
=
self
.
partial_rotary_factor
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
...
...
@@ -453,8 +453,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/olmo.py
View file @
af7f4372
...
...
@@ -311,8 +311,11 @@ class OlmoForCausalLM(nn.Module):
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/opt.py
View file @
af7f4372
...
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -307,7 +307,11 @@ class OPTForCausalLM(nn.Module):
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
OPTModel
(
config
,
cache_config
,
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
decoder
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
word_embed_proj_dim
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -323,8 +327,11 @@ class OPTForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/orion.py
View file @
af7f4372
...
...
@@ -262,6 +262,8 @@ class OrionForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -277,8 +279,11 @@ class OrionForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
Prev
1
…
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment