Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
448
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
617 additions
and
428 deletions
+617
-428
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+8
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+21
-20
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+20
-24
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+134
-71
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+61
-22
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+188
-160
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+12
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+173
-126
No files found.
Too many changes to show.
To preserve performance only
448 of 448+
files are displayed.
Plain diff
Email patch
vllm/model_executor/models/gpt_neox.py
View file @
af7f4372
...
...
@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def
__init__
(
self
,
config
,
config
:
GPTNeoXConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
...
...
@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config
.
hidden_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
embed_out
.
weight
=
self
.
gpt_neox
.
embed_in
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -258,8 +260,11 @@ class GPTNeoXForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
embed_out
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/interfaces.py
View file @
af7f4372
from
typing
import
(
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
from
typing_extensions
import
Type
Guard
from
typing_extensions
import
Type
Is
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
...
...
@@ -10,12 +10,12 @@ logger = init_logger(__name__)
@
runtime_checkable
class
Supports
Vision
(
Protocol
):
"""The interface required for all
vision language models (VLMs)
."""
class
Supports
MultiModal
(
Protocol
):
"""The interface required for all
multi-modal models
."""
supports_
vision
:
ClassVar
[
Literal
[
True
]]
=
True
supports_
multimodal
:
ClassVar
[
Literal
[
True
]]
=
True
"""
A flag that indicates this model supports
vision
inputs.
A flag that indicates this model supports
multi-modal
inputs.
Note:
There is no need to redefine this flag if this class is in the
...
...
@@ -29,30 +29,31 @@ class SupportsVision(Protocol):
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@
runtime_checkable
class
_Supports
Vision
Type
(
Protocol
):
supports_
vision
:
Literal
[
True
]
class
_Supports
MultiModal
Type
(
Protocol
):
supports_
multimodal
:
Literal
[
True
]
def
__call__
(
self
,
*
,
multimodal_config
:
MultiModalConfig
)
->
None
:
...
@
overload
def
supports_vision
(
model
:
Type
[
object
])
->
TypeGuard
[
Type
[
SupportsVision
]]:
def
supports_multimodal
(
model
:
Type
[
object
])
->
TypeIs
[
Type
[
SupportsMultiModal
]]:
...
@
overload
def
supports_
vision
(
model
:
object
)
->
Type
Guard
[
Supports
Vision
]:
def
supports_
multimodal
(
model
:
object
)
->
Type
Is
[
Supports
MultiModal
]:
...
def
supports_
vision
(
def
supports_
multimodal
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
Type
Guard
[
Type
[
Supports
Vision
]],
Type
Guard
[
Supports
Vision
]]:
)
->
Union
[
Type
Is
[
Type
[
Supports
MultiModal
]],
Type
Is
[
Supports
MultiModal
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_Supports
Vision
Type
)
return
isinstance
(
model
,
_Supports
MultiModal
Type
)
return
isinstance
(
model
,
Supports
Vision
)
return
isinstance
(
model
,
Supports
MultiModal
)
@
runtime_checkable
...
...
@@ -94,18 +95,18 @@ class _SupportsLoRAType(Protocol):
@
overload
def
supports_lora
(
model
:
Type
[
object
])
->
Type
Guard
[
Type
[
SupportsLoRA
]]:
def
supports_lora
(
model
:
Type
[
object
])
->
Type
Is
[
Type
[
SupportsLoRA
]]:
...
@
overload
def
supports_lora
(
model
:
object
)
->
Type
Guard
[
SupportsLoRA
]:
def
supports_lora
(
model
:
object
)
->
Type
Is
[
SupportsLoRA
]:
...
def
supports_lora
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
Type
Guard
[
Type
[
SupportsLoRA
]],
Type
Guard
[
SupportsLoRA
]]:
)
->
Union
[
Type
Is
[
Type
[
SupportsLoRA
]],
Type
Is
[
SupportsLoRA
]]:
result
=
_supports_lora
(
model
)
if
not
result
:
...
...
@@ -137,7 +138,7 @@ def supports_lora(
def
_supports_lora
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
Type
Guard
[
Type
[
SupportsLoRA
]],
Type
Guard
[
SupportsLoRA
]]:
)
->
Union
[
Type
Is
[
Type
[
SupportsLoRA
]],
Type
Is
[
SupportsLoRA
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_SupportsLoRAType
)
...
...
@@ -172,18 +173,18 @@ class _HasInnerStateType(Protocol):
@
overload
def
has_inner_state
(
model
:
object
)
->
Type
Guard
[
HasInnerState
]:
def
has_inner_state
(
model
:
object
)
->
Type
Is
[
HasInnerState
]:
...
@
overload
def
has_inner_state
(
model
:
Type
[
object
])
->
Type
Guard
[
Type
[
HasInnerState
]]:
def
has_inner_state
(
model
:
Type
[
object
])
->
Type
Is
[
Type
[
HasInnerState
]]:
...
def
has_inner_state
(
model
:
Union
[
Type
[
object
],
object
]
)
->
Union
[
Type
Guard
[
Type
[
HasInnerState
]],
Type
Guard
[
HasInnerState
]]:
)
->
Union
[
Type
Is
[
Type
[
HasInnerState
]],
Type
Is
[
HasInnerState
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_HasInnerStateType
)
...
...
vllm/model_executor/models/internlm2.py
View file @
af7f4372
...
...
@@ -87,6 +87,7 @@ class InternLM2Attention(nn.Module):
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
key_value_groups
=
int
(
self
.
num_heads
/
self
.
num_kv_heads
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
...
...
@@ -120,6 +121,14 @@ class InternLM2Attention(nn.Module):
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
):
qkv
=
qkv
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
key_value_groups
+
2
,
128
)
q
,
k
,
v
=
torch
.
split
(
qkv
,
[
self
.
key_value_groups
,
1
,
1
],
dim
=
2
)
q
=
q
.
reshape
(
-
1
,
self
.
q_size
)
k
=
k
.
reshape
(
-
1
,
self
.
kv_size
)
v
=
v
.
reshape
(
-
1
,
self
.
kv_size
)
return
q
,
k
,
v
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -128,7 +137,7 @@ class InternLM2Attention(nn.Module):
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
wqkv
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
self
.
split_qkv
(
qkv
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
wo
(
attn_output
)
...
...
@@ -264,6 +273,8 @@ class InternLM2ForCausalLM(nn.Module):
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -279,8 +290,11 @@ class InternLM2ForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
output
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
@@ -319,24 +333,6 @@ class InternLM2ForCausalLM(nn.Module):
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
if
"wqkv"
in
name
:
config
=
self
.
config
kv_groups
=
(
config
.
num_attention_heads
//
config
.
num_key_value_heads
)
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
loaded_weight
=
loaded_weight
.
view
(
-
1
,
2
+
kv_groups
,
head_dim
,
loaded_weight
.
shape
[
-
1
])
wq
,
wk
,
wv
=
torch
.
split
(
loaded_weight
,
[
kv_groups
,
1
,
1
],
dim
=
1
)
wq
=
wq
.
reshape
(
-
1
,
wq
.
shape
[
-
1
])
wk
=
wk
.
reshape
(
-
1
,
wk
.
shape
[
-
1
])
wv
=
wv
.
reshape
(
-
1
,
wv
.
shape
[
-
1
])
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
wq
,
'q'
)
weight_loader
(
param
,
wk
,
'k'
)
weight_loader
(
param
,
wv
,
'v'
)
else
:
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/internvl.py
View file @
af7f4372
...
...
@@ -5,7 +5,8 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import
itertools
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
...
...
@@ -18,18 +19,18 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.intern_vit
import
InternVisionModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.
image
import
cached_get_tokenizer
from
vllm.multimodal.
utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_num_patches
)
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
from
.interfaces
import
SupportsMultiModal
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
IMG_START
=
'<img>'
IMG_END
=
'</img>'
...
...
@@ -38,9 +39,6 @@ IMG_CONTEXT = '<IMG_CONTEXT>'
IMAGENET_MEAN
=
(
0.485
,
0.456
,
0.406
)
IMAGENET_STD
=
(
0.229
,
0.224
,
0.225
)
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
3000
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
500
class
InternVLImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
...
...
@@ -53,6 +51,19 @@ class InternVLImagePixelInputs(TypedDict):
"""
class
InternVLImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternVLImageInputs
=
Union
[
InternVLImagePixelInputs
,
InternVLImageEmbeddingInputs
]
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def
build_transform
(
input_size
):
MEAN
,
STD
=
IMAGENET_MEAN
,
IMAGENET_STD
...
...
@@ -84,11 +95,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
return
best_ratio
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
=
1
,
max_num
=
6
,
image_size
=
448
):
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
)
->
Tuple
[
int
,
int
,
int
]:
aspect_ratio
=
orig_width
/
orig_height
# calculate the existing image aspect ratio
...
...
@@ -110,11 +119,9 @@ def calculate_num_blocks(orig_width: int,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
dynamic_preprocess
(
image
,
min_num
=
1
,
max_num
=
6
,
image_size
=
448
,
use_thumbnail
=
False
):
def
dynamic_preprocess
(
image
:
Image
.
Image
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
,
use_thumbnail
:
int
)
->
List
[
Image
.
Image
]:
orig_width
,
orig_height
=
image
.
size
blocks
,
target_width
,
target_height
=
calculate_num_blocks
(
...
...
@@ -138,12 +145,14 @@ def dynamic_preprocess(image,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
image_to_pixel_values
(
image
:
Image
.
Image
,
input_size
=
448
,
max_num
=
6
):
def
image_to_pixel_values
(
image
:
Image
.
Image
,
input_size
:
int
,
min_num
:
int
,
max_num
:
int
,
use_thumbnail
:
bool
)
->
torch
.
Tensor
:
transform
=
build_transform
(
input_size
=
input_size
)
images
=
dynamic_preprocess
(
image
,
min_num
=
min_num
,
max_num
=
max_num
,
image_size
=
input_size
,
use_thumbnail
=
True
,
max_num
=
max_num
)
use_thumbnail
=
use_thumbnail
)
pixel_values
=
[
transform
(
image
)
for
image
in
images
]
pixel_values
=
torch
.
stack
(
pixel_values
)
return
pixel_values
...
...
@@ -157,14 +166,20 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
use_thumbnail
=
hf_config
.
use_thumbnail
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
if
use_thumbnail
:
max_dynamic_patch
+=
1
downsample_ratio
=
hf_config
.
downsample_ratio
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
return
num_patches
*
7
return
num_patches
*
max_dynamic_patch
def
input_processor_for_internvl
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
...
...
@@ -173,24 +188,32 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
width
,
height
=
image_data
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
)
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
max_num
,
image_size
)
# add thumbnail image if num_blocks > 1
if
hf_config
.
use_thumbnail
and
num_blocks
>
1
:
num_blocks
+=
1
image_feature_size
=
num_blocks
*
num_patches
elif
isinstance
(
image_data
,
torch
.
Tensor
):
raise
NotImplementedError
(
"Embeddings input is not supported yet"
)
image_feature_size
=
image_data
.
shape
[
0
]
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -198,8 +221,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
(
num_blocks
+
1
)
*
num_patches
+
IMG_END
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
image_feature_size
+
IMG_END
new_prompt
=
prompt
.
replace
(
'<image>'
,
image_prompt
,
1
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
...
...
@@ -209,8 +231,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
):
hf_config
=
ctx
.
get_hf_config
()
use_thumbnail
=
hf_config
.
use_thumbnail
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
image_size
=
hf_config
.
vision_config
.
image_size
if
isinstance
(
data
,
Image
.
Image
):
data
=
image_to_pixel_values
(
data
)
data
=
image_to_pixel_values
(
data
,
image_size
,
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -224,11 +257,13 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
})
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
PretrainedConfig
)
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -236,14 +271,23 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
add_special_tokens
=
False
)[
0
],
image_feature_size_override
=
image_feature_size
,
)
image_size
=
vision_config
.
image_size
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
max_image_width
=
max_num
*
image_size
max_image_height
=
min_num
*
image_size
mm_data
=
dummy_image_for_clip
(
vision_config
,
image_width_override
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height_override
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
,
num_images
,
image_width_override
=
max_image_width
,
image_height_override
=
max_image_height
,
)
return
seq_data
,
mm_data
...
...
@@ -253,7 +297,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_internvl_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_internvl
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_internvl
)
class
InternVLChatModel
(
nn
.
Module
,
Supports
Vision
):
class
InternVLChatModel
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
...
...
@@ -283,10 +327,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
self
.
vision_model
=
InternVisionModel
(
config
.
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
)
llm_class
=
ModelRegistry
.
load_model_cls
(
config
.
text_config
.
architectures
[
0
])
self
.
language_model
=
llm_class
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
vit_hidden_size
=
config
.
vision_config
.
hidden_size
llm_hidden_size
=
config
.
text_config
.
hidden_size
...
...
@@ -356,23 +398,49 @@ class InternVLChatModel(nn.Module, SupportsVision):
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
InternVLImage
Pixel
Inputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
InternVLImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_token_id
=
kwargs
.
pop
(
"image_token_id"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
:
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
InternVLImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
)
self
.
img_context_token_id
=
image_token_id
[
0
]
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_process_image_input
(
self
,
image_input
:
InternVLImageInputs
,
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
assert
self
.
vision_model
is
not
None
image_embeds
=
self
.
extract_feature
(
image_input
[
"data"
])
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
return
image_embeds
def
forward
(
self
,
...
...
@@ -387,10 +455,10 @@ class InternVLChatModel(nn.Module, SupportsVision):
if
image_input
is
not
None
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
vi
t
_embeds
=
self
.
extract_feature
(
image_input
[
"data"
]
)
inputs_embeds
=
merge_
vision
_embeddings
(
input_ids
,
inputs_embeds
,
vit
_embeds
,
self
.
img_context_token_id
)
vi
sion
_embed
ding
s
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision
_embed
ding
s
,
self
.
img_context_token_id
)
input_ids
=
None
else
:
inputs_embeds
=
None
...
...
@@ -403,8 +471,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
...
...
@@ -415,24 +486,16 @@ class InternVLChatModel(nn.Module, SupportsVision):
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
_filter_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
):
for
name
,
loaded_weight
in
weights
:
name
=
name
.
split
(
"."
)
if
prefix
==
name
.
pop
(
0
):
name
=
"."
.
join
(
name
)
yield
name
,
loaded_weight
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
vit_weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
# load vision encoder
vit_weights
=
self
.
_
filter_weights
(
vit_weights
,
"vision_model"
)
vit_weights
=
filter_weights
(
vit_weights
,
"vision_model"
)
self
.
vision_model
.
load_weights
(
vit_weights
)
# load mlp projector
mlp_weights
=
self
.
_
filter_weights
(
mlp_weights
,
"mlp1"
)
mlp_weights
=
filter_weights
(
mlp_weights
,
"mlp1"
)
mlp_params_dict
=
dict
(
self
.
mlp1
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_weights
:
param
=
mlp_params_dict
[
name
]
...
...
@@ -441,5 +504,5 @@ class InternVLChatModel(nn.Module, SupportsVision):
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
self
.
_
filter_weights
(
llm_weights
,
"language_model"
)
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
vllm/model_executor/models/jais.py
View file @
af7f4372
...
...
@@ -20,14 +20,14 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -37,12 +37,14 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.transformers_utils.configs
import
JAISConfig
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
SwiGLUActivation
(
nn
.
Module
):
...
...
@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
config
:
JAISConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
self
.
embeddings_scale
=
config
.
embeddings_scale
else
:
self
.
embeddings_scale
=
config
.
mup_embeddings_scale
self
.
h
=
nn
.
ModuleList
([
JAISBlock
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
JAISBlock
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.h"
,
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
...
...
@@ -243,19 +251,29 @@ class JAISModel(nn.Module):
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
if
self
.
wpe
is
not
None
:
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
if
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
wte
(
input_ids
)
if
self
.
wpe
is
not
None
:
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
else
:
hidden_states
=
inputs_embeds
hidden_states
*=
torch
.
tensor
(
float
(
self
.
embeddings_scale
),
dtype
=
hidden_states
.
dtype
)
else
:
hidden_states
=
inputs_embeds
hidden_states
*=
torch
.
tensor
(
float
(
self
.
embeddings_scale
),
dtype
=
hidden_states
.
dtype
)
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
len
(
self
.
h
)
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -273,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
JAISModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
self
.
transformer
.
wte
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
if
hasattr
(
config
,
"width_scale"
):
self
.
output_logits_scale
=
config
.
width_scale
else
:
...
...
@@ -290,17 +312,30 @@ class JAISLMHeadModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
torch
.
Tensor
,
...
...
@@ -324,6 +359,10 @@ class JAISLMHeadModel(nn.Module):
continue
if
not
name
.
startswith
(
"transformer."
):
name
=
"transformer."
+
name
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
...
...
vllm/model_executor/models/jamba.py
View file @
af7f4372
...
...
@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
return
hidden_states
class
JambaMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
JambaConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
hidden_size
=
config
.
hidden_size
intermediate_size
=
config
.
intermediate_size
hidden_act
=
config
.
hidden_act
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
JambaMoE
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
return
hidden_states
.
view
(
orig_shape
)
class
JambaMLP
(
JambaMoE
):
def
__init__
(
self
,
config
:
JambaConfig
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
(
config
,
num_experts
=
1
,
top_k
=
1
,
params_dtype
=
params_dtype
,
tp_size
=
tp_size
,
quant_config
=
quant_config
)
class
JambaMambaDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -609,12 +592,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
# Current step used indices
self
.
current_indices
:
List
[
int
]
=
[]
# Used to track and store by the Mamba cache between steps.
self
.
mamba_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Used as an input_buffer for the CUDA graph runs.
self
.
mamba_gc_cache_buffer
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
=
tuple
()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self
.
mamba_cache_indices_mapping
:
Dict
[
str
,
Dict
[
int
,
int
]]
=
{}
...
...
@@ -644,95 +623,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
batch_size
=
input_ids
.
shape
[
0
]
if
attn_metadata
.
prefill_metadata
:
batch_size
=
len
(
request_ids_to_seq_ids
)
(
current_seqlen_agnostic_cache
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
mamba_cache
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
batch_size
,
finished_requests_ids
)
else
:
# CUDA graph capturing runs
current_seqlen_agnostic_cache
,
indices
=
(
kwargs
[
"seqlen_agnostic_capture_inputs"
],
[],
)
self
.
current_indices
=
indices
mamba_cache
=
kwargs
[
"seqlen_agnostic_capture_inputs"
]
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
current_seqlen_agnostic_cache
[
0
],
current_seqlen_agnostic_cache
[
1
])
if
"seqlen_agnostic_capture_inputs"
not
in
kwargs
:
self
.
_copy_mamba_cache_by_indices
(
self
.
current_indices
,
current_seqlen_agnostic_cache
)
attn_metadata
,
mamba_cache
[
0
],
mamba_cache
[
1
])
return
hidden_states
def
_
copy
_mamba_cache
_by_indices
(
self
,
indices
:
List
[
int
],
current_seqlen_agnostic_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
])
:
for
i
,
offset
in
enumerate
(
indices
):
self
.
_copy_mamba_cache
(
offset
,
i
,
current_seqlen_agnostic_cache
)
def
_
swap
_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
[
to_index
,
from_index
]]
=
\
cache_t
[:,
[
from_index
,
to_index
]]
def
_copy_mamba_cache
(
self
,
index_to
:
int
,
index_from
:
int
,
from_buffer
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]):
def
_copy_mamba_cache
(
self
,
from_index
:
int
,
to_index
:
int
):
assert
len
(
self
.
mamba_cache
)
>
0
for
(
cache_t
,
from_buffer_t
)
in
zip
(
self
.
mamba_cache
,
from_buffer
)
:
cache_t
[:,
index
_to
].
copy_
(
from_buffer_t
[:,
index_from
],
for
cache_t
in
self
.
mamba_cache
:
cache_t
[:,
to_
index
].
copy_
(
cache_t
[:,
from_index
],
non_blocking
=
True
)
def
_assign_seq_id_to_mamba_cache
(
self
,
cur_rid
:
str
,
seqs_id
:
List
[
int
])
->
List
[
int
]:
indices_for_current_run
=
[]
for
seq_id
in
seqs_id
:
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{}
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
first_free_index
index_for_current_run
=
first_free_index
## case of decoding n>1, copy prefill cache to decoding indices
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
index_exist
=
list
(
seq_ids2indices
.
values
())[
0
]
self
.
_copy_mamba_cache
(
index_from
=
index_exist
,
index_to
=
first_free_index
,
from_buffer
=
self
.
mamba_cache
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
first_free_index
index_for_current_run
=
first_free_index
else
:
index_for_current_run
=
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
indices_for_current_run
.
append
(
index_for_current_run
)
return
indices_for_current_run
def
_move_out_if_already_occupied
(
self
,
index
:
int
,
all_occupied_indices
:
List
[
int
]):
if
index
in
all_occupied_indices
:
first_free_index
=
self
.
_first_free_index_in_mamba_cache
()
# In case occupied, move the occupied to a new empty block
self
.
_move_cache_index_and_mappings
(
from_index
=
index
,
to_index
=
first_free_index
)
def
_assign_seq_id_to_mamba_cache_in_specific_dest
(
self
,
cur_rid
:
str
,
seq_id
:
int
,
destination_index
:
int
):
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
if
cur_rid
not
in
self
.
mamba_cache_indices_mapping
:
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
self
.
mamba_cache_indices_mapping
[
cur_rid
]
=
{
seq_id
:
destination_index
}
elif
seq_id
not
in
(
seq_ids2indices
:
=
self
.
mamba_cache_indices_mapping
[
cur_rid
]):
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self
.
_move_out_if_already_occupied
(
index
=
destination_index
,
all_occupied_indices
=
all_occupied_indices
)
index_exists
=
list
(
seq_ids2indices
.
values
())[
0
]
# case of decoding n>1, copy prefill cache to decoding indices
self
.
_copy_mamba_cache
(
from_index
=
index_exists
,
to_index
=
destination_index
)
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
=
destination_index
else
:
# already exists
cache_index_already_exists
=
self
.
mamba_cache_indices_mapping
[
cur_rid
][
seq_id
]
if
cache_index_already_exists
!=
destination_index
:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self
.
_swap_pair_indices_and_mappings
(
from_index
=
cache_index_already_exists
,
to_index
=
destination_index
)
def
_prepare_current_run_mamba_cache
(
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]
)
->
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
List
[
int
]]:
indices_for_current_run
=
[]
for
request_id
,
seqs_id
in
request_ids_to_seq_ids
.
items
():
self
,
request_ids_to_seq_ids
:
Dict
[
str
,
list
[
int
]],
batch_size
:
int
,
finished_requests_ids
:
List
[
str
]):
running_indices
=
[]
request_ids_to_seq_ids_flatten
=
[
(
req_id
,
seq_id
)
for
req_id
,
seq_ids
in
request_ids_to_seq_ids
.
items
()
for
seq_id
in
seq_ids
]
for
dest_index
,
(
request_id
,
seq_id
)
in
enumerate
(
request_ids_to_seq_ids_flatten
):
if
request_id
in
finished_requests_ids
:
# Do not allocate cache for requests that run
# Do not allocate cache
index
for requests that run
# and finish right after
continue
indices_for_current_run
+=
self
.
_assign_seq_id_to_mamba_cache
(
request_id
,
seqs_id
)
## Pad the batch in case of running batch that was not captured via CG
padded_indices
=
indices_for_current_run
.
copy
()
pad_index
=
self
.
_first_free_index_in_mamba_cache
()
self
.
_assign_seq_id_to_mamba_cache_in_specific_dest
(
request_id
,
seq_id
,
dest_index
)
running_indices
.
append
(
dest_index
)
for
_
in
range
(
batch_size
-
len
(
indices_for_current_run
)):
padded_indices
.
append
(
pad_index
)
self
.
_clean_up_first_bs_blocks
(
batch_size
,
running_indices
)
conv_state
=
self
.
mamba_cache
[
0
][:,
:
batch_size
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
:
batch_size
]
conv_state
=
self
.
mamba_cache
[
0
][:,
padded_indices
]
temporal_state
=
self
.
mamba_cache
[
1
][:,
padded_indices
]
return
(
conv_state
,
temporal_state
)
def
_get_all_occupied_indices
(
self
):
return
[
cache_idx
for
seq_ids2indices
in
self
.
mamba_cache_indices_mapping
.
values
()
for
cache_idx
in
seq_ids2indices
.
values
()
]
return
(
conv_state
,
temporal_state
),
indices_for_current_run
def
_clean_up_first_bs_blocks
(
self
,
batch_size
:
int
,
indices_for_current_run
:
List
[
int
]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices
=
set
([
range
(
batch_size
)])
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
for
destination_index
in
destination_indices
:
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
destination_index
not
in
indices_for_current_run
:
# move not running indices outside of the batch
all_other_indices
=
list
(
range
(
batch_size
,
max_possible_batch_size
))
first_avail_index
=
self
.
_first_free_index_in_mamba_cache
(
all_other_indices
)
self
.
_swap_indices
(
from_index
=
destination_index
,
to_index
=
first_avail_index
)
def
_move_cache_index_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_copy_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_update_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_pair_indices_and_mappings
(
self
,
from_index
:
int
,
to_index
:
int
):
self
.
_swap_mamba_cache
(
from_index
=
from_index
,
to_index
=
to_index
)
self
.
_swap_mapping_index
(
from_index
=
from_index
,
to_index
=
to_index
)
def
_swap_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
elif
to_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
from_index
})
def
_update_mapping_index
(
self
,
from_index
:
int
,
to_index
:
int
):
for
seq_ids2index
in
self
.
mamba_cache_indices_mapping
.
values
():
for
seq_id
,
index
in
seq_ids2index
.
items
():
if
from_index
==
index
:
seq_ids2index
.
update
({
seq_id
:
to_index
})
return
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
...
...
@@ -747,28 +779,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
self
.
_release_mamba_cache
(
finished_requests_ids
)
request_ids_to_seq_ids
=
kwargs
[
"request_ids_to_seq_ids"
]
cg_batch_size
=
input_buffers
[
'input_ids'
].
shape
[
0
]
(
current_mamba_cache
,
indices
,
)
=
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
cg_batch_size
,
finished_requests_ids
)
self
.
current_indices
=
indices
for
input_buffer
,
current_cache_buffer
in
zip
(
input_buffers
[
"seqlen_agnostic_capture_inputs"
],
current_mamba_cache
):
input_buffer
.
copy_
(
current_cache_buffer
,
non_blocking
=
True
)
def
copy_outputs_after_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
"""
Copy the relevant Mamba cache from the CUDA graph input_buffers
back to the JambaForCausalLM.mamba_cache after CUDA
graph replay run is done.
"""
self
.
_copy_mamba_cache_by_indices
(
self
.
current_indices
,
input_buffers
[
"seqlen_agnostic_capture_inputs"
])
self
.
_prepare_current_run_mamba_cache
(
request_ids_to_seq_ids
,
cg_batch_size
,
finished_requests_ids
)
def
get_seqlen_agnostic_capture_inputs
(
self
,
batch_size
:
int
):
"""
...
...
@@ -776,26 +789,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_gc_cache_buffer
)
return
tuple
(
buffer
[:,
:
batch_size
]
for
buffer
in
self
.
mamba_cache
)
def
_release_mamba_cache
(
self
,
finished_seq_groups_req_ids
:
List
[
str
]):
for
req_id
in
finished_seq_groups_req_ids
:
if
req_id
in
self
.
mamba_cache_indices_mapping
:
self
.
mamba_cache_indices_mapping
.
pop
(
req_id
)
def
_first_free_index_in_mamba_cache
(
self
)
->
int
:
if
self
.
mamba_cache
:
def
_first_free_index_in_mamba_cache
(
self
,
indices_range
:
Optional
[
List
[
int
]]
=
None
)
->
int
:
assert
self
.
mamba_cache
is
not
None
if
indices_range
is
None
:
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
occupied
=
[
id
for
seq_ids
in
self
.
mamba_cache_indices_mapping
.
values
()
for
id
in
seq_ids
.
values
()
]
first_free_index
=
[
i
not
in
occupied
for
i
in
range
(
max_possible_batch_size
)
].
index
(
True
)
return
first_free_index
return
0
indices_range
=
list
(
range
(
max_possible_batch_size
))
all_occupied_indices
=
self
.
_get_all_occupied_indices
()
for
i
in
indices_range
:
if
i
not
in
all_occupied_indices
:
return
i
raise
Exception
(
"Couldn't find a free spot in the mamba cache! This"
"should never happen"
)
def
_get_mamba_cache_shape
(
self
...
...
@@ -819,23 +831,24 @@ class JambaForCausalLM(nn.Module, HasInnerState):
[
layer_type
==
"mamba"
for
layer_type
in
layers_type
])
max_batch_size
=
(
_get_graph_batch_size
(
self
.
scheduler_config
.
max_num_seqs
)
if
self
.
scheduler_config
else
max
(
_BATCH_SIZES_TO_CAPTURE
)
)
+
10
max
(
_BATCH_SIZES_TO_CAPTURE
)
+
2
)
conv_state_shape
,
temporal_state_shape
=
self
.
_get_mamba_cache_shape
()
assert
conv_state_shape
is
not
None
and
temporal_state_shape
is
not
None
for
buffername
in
[
"mamba_cache"
,
"mamba_gc_cache_buffer"
]:
buffer
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
conv_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
),
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
temporal_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
))
setattr
(
self
,
buffername
,
buffer
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
.
mamba_cache
=
(
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
conv_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
),
torch
.
empty
(
size
=
(
mamba_layers
,
max_batch_size
)
+
temporal_state_shape
,
dtype
=
dtype
,
device
=
"cuda"
))
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
@@ -854,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
...
...
@@ -877,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
if
".self_attn."
in
name
:
name
=
name
.
replace
(
".self_attn"
,
""
)
if
"feed_forward"
in
name
and
not
_is_moe_layer
(
name
):
## map MLP layers to expert with ID=0
name
=
name
.
replace
(
"feed_forward"
,
"feed_forward.experts.0"
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
...
...
@@ -891,10 +906,15 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
for
(
param_name
,
weight_name
,
expert_id
,
shard_id
,
)
in
expert_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
...
...
@@ -913,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
_is_moe_layer
(
name
:
str
):
return
any
(
[
experts_name
in
name
for
experts_name
in
[
"experts"
,
"router"
,
]])
vllm/model_executor/models/llama.py
View file @
af7f4372
...
...
@@ -145,6 +145,7 @@ class LlamaAttention(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
output_size
=
hidden_size
,
...
...
@@ -153,12 +154,17 @@ class LlamaAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
is_neox_style
=
True
if
quant_config
is
not
None
and
quant_config
.
get_name
()
==
"gguf"
:
is_neox_style
=
False
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
is_neox_style
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
...
...
@@ -291,6 +297,7 @@ class LlamaModel(nn.Module):
self
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
quant_config
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
...
...
@@ -444,8 +451,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
,
intermediate_tensors
)
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/llava.py
View file @
af7f4372
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
from
transformers
import
CLIPVisionConfig
,
LlavaConfig
from
transformers
import
CLIPVisionConfig
,
LlavaConfig
,
SiglipVisionConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_max_clip_image_tokens
,
input_processor_for_clip
)
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_max_clip_image_tokens
,
input_processor_for_clip
)
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
,
input_processor_for_siglip
)
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
class
LlavaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
class
LlavaImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaImageInputs
=
Union
[
LlavaImagePixelInputs
,
LlavaImageEmbeddingInputs
]
# TODO(xwjiang): Run benchmark and decide if TP.
...
...
@@ -53,38 +67,56 @@ class LlavaMultiModalProjector(nn.Module):
return
hidden_states
class
LlavaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
LlavaImageInputs
=
LlavaImagePixelInputs
def
get_max_llava_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
get_max_clip_image_tokens
(
vision_config
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
num_image_tokens
=
get_max_clip_image_tokens
(
vision_config
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
num_image_tokens
=
get_max_siglip_image_tokens
(
vision_config
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
strategy
=
hf_config
.
vision_feature_select_strategy
if
strategy
==
"default"
:
return
num_image_tokens
-
1
elif
strategy
==
"full"
:
return
num_image_tokens
else
:
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def
dummy_data_for_llava
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_llava
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_llava_image_tokens
(
ctx
)
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_clip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
seq_data
=
dummy_seq_data_for_siglip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_
c
lip
(
vision_config
)
mm_data
=
dummy_image_for_
sig
lip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
...
...
@@ -100,12 +132,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
image_feature_size
=
get_max_llava_image_tokens
(
ctx
)
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
input_processor_for_clip
(
model_config
,
vision_config
,
llm_inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
input_processor_for_siglip
(
model_config
,
vision_config
,
llm_inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
def
_init_vision_tower
(
hf_config
:
LlavaConfig
):
vision_config
=
hf_config
.
vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer
=
hf_config
.
vision_feature_layer
if
vision_feature_layer
<
0
:
num_hidden_layers
=
hf_config
.
vision_config
.
num_hidden_layers
\
+
vision_feature_layer
+
1
else
:
num_hidden_layers
=
vision_feature_layer
+
1
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
CLIPVisionModel
(
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
SiglipVisionModel
(
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
...
...
@@ -116,7 +185,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava
)
class
LlavaForConditionalGeneration
(
nn
.
Module
,
Supports
Vision
):
class
LlavaForConditionalGeneration
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
LlavaConfig
,
...
...
@@ -128,36 +197,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer
=
config
.
vision_feature_layer
if
vision_feature_layer
<
0
:
num_hidden_layers
=
config
.
vision_config
.
num_hidden_layers
\
+
vision_feature_layer
+
1
else
:
num_hidden_layers
=
vision_feature_layer
+
1
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_tower
=
CLIPVisionModel
(
config
.
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
)
self
.
vision_tower
=
_init_vision_tower
(
config
)
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
)
self
.
quant_config
=
quant_config
self
.
language_model
=
LlamaModel
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
text_config
.
hidden_size
,
org_num_embeddings
=
self
.
language_model
.
org_vocab_size
,
quant_config
=
quant_config
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
text_config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
...
...
@@ -175,18 +223,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
:
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
LlavaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
LlavaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
LlavaImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_select_image_features
(
self
,
image_features
:
torch
.
Tensor
,
*
,
strategy
:
str
)
->
torch
.
Tensor
:
...
...
@@ -198,8 +258,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
CLIPVisionModel
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_image_pixels_to_features
(
self
,
vision_tower
:
Union
[
CLIPVisionModel
,
SiglipVisionModel
],
pixel_values
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
...
...
@@ -220,6 +283,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def
_process_image_input
(
self
,
image_input
:
LlavaImageInputs
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
assert
self
.
vision_tower
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
return
self
.
multi_modal_projector
(
image_features
)
...
...
@@ -246,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
...
...
@@ -264,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`LlavaImageInputs`
"""
...
...
@@ -272,9 +339,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_
vision
_embeddings
(
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
...
...
@@ -282,68 +350,47 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
inputs_embeds
=
inputs_embeds
)
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# only doing this for language model part for now.
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
# post_layernorm is not needed in CLIPVisionModel
if
"vision_model.post_layernorm"
in
name
:
continue
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
self
.
vision_tower
is
not
None
:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading
=
True
else
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
and
name
in
params_dict
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# prepare weight iterators for components
vit_weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
# load vision encoder
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_weights
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
Prev
1
…
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment