Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d2051cc
Commit
6d2051cc
authored
Oct 21, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev
parents
2c7f740a
a2c71c54
Changes
457
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2060 additions
and
578 deletions
+2060
-578
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+43
-20
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+14
-8
vllm/model_executor/models/exaone.py
vllm/model_executor/models/exaone.py
+12
-23
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+44
-19
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+72
-69
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+64
-21
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+161
-61
vllm/model_executor/models/glm4_vision_encoder.py
vllm/model_executor/models/glm4_vision_encoder.py
+299
-0
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+11
-15
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+36
-18
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+35
-15
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+33
-15
vllm/model_executor/models/granite.py
vllm/model_executor/models/granite.py
+8
-7
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+448
-0
vllm/model_executor/models/idefics2_vision_model.py
vllm/model_executor/models/idefics2_vision_model.py
+15
-9
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+174
-13
vllm/model_executor/models/interfaces_base.py
vllm/model_executor/models/interfaces_base.py
+191
-0
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+134
-72
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+5
-5
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+261
-188
No files found.
Too many changes to show.
To preserve performance only
457 of 457+
files are displayed.
Plain diff
Email patch
vllm/model_executor/models/deepseek.py
View file @
6d2051cc
...
...
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
...
@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -50,6 +49,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
DeepseekMLP
(
nn
.
Module
):
...
...
@@ -329,6 +332,7 @@ class DeepseekModel(nn.Module):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -338,14 +342,17 @@ class DeepseekModel(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
DeepseekDecoderLayer
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant
_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
]
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
DeepseekDecoderLayer
(
config
,
int
(
prefix
.
split
(
"."
)[
-
1
])
,
cache
_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
forward
(
self
,
...
...
@@ -353,19 +360,29 @@ class DeepseekModel(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
residual
)
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
DeepseekForCausalLM
(
nn
.
Module
):
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
...
...
@@ -384,6 +401,8 @@ class DeepseekForCausalLM(nn.Module):
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -392,9 +411,9 @@ class DeepseekForCausalLM(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -439,6 +458,8 @@ class DeepseekForCausalLM(nn.Module):
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -451,6 +472,8 @@ class DeepseekForCausalLM(nn.Module):
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
6d2051cc
...
...
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2 model."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -50,7 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
DeepseekV2MLP
(
nn
.
Module
):
...
...
@@ -241,7 +242,7 @@ class DeepseekV2Attention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
)
rope_scaling
[
'
type
'
]
=
'deepseek_yarn'
rope_scaling
[
"rope_
type
"
]
=
'deepseek_yarn'
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
...
...
@@ -439,6 +440,9 @@ class DeepseekV2Model(nn.Module):
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
forward
(
self
,
...
...
@@ -447,7 +451,7 @@ class DeepseekV2Model(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
...
...
@@ -472,7 +476,7 @@ class DeepseekV2Model(nn.Module):
return
hidden_states
class
DeepseekV2ForCausalLM
(
nn
.
Module
):
class
DeepseekV2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
...
...
@@ -492,6 +496,8 @@ class DeepseekV2ForCausalLM(nn.Module):
quant_config
=
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -500,7 +506,7 @@ class DeepseekV2ForCausalLM(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
hidden_states
...
...
vllm/model_executor/models/exaone.py
View file @
6d2051cc
...
...
@@ -38,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
get_compressed_tensors_cache_scale
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
...
...
@@ -53,8 +52,9 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs.exaone
import
ExaoneConfig
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
ExaoneGatedMLP
(
nn
.
Module
):
...
...
@@ -354,6 +354,10 @@ class ExaoneModel(nn.Module):
else
:
self
.
ln_f
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
wte
(
input_ids
)
...
...
@@ -397,7 +401,7 @@ class ExaoneModel(nn.Module):
return
hidden_states
class
ExaoneForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
class
ExaoneForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -477,6 +481,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -506,24 +513,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
(
(
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
,
),
"residual"
:
torch
.
zeros
(
(
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
,
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
...
...
vllm/model_executor/models/falcon.py
View file @
6d2051cc
...
...
@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
get_act_fn
...
...
@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -47,6 +46,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
RWConfig
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
...
@@ -333,6 +336,7 @@ class FalconModel(nn.Module):
config
:
FalconConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -347,35 +351,56 @@ class FalconModel(nn.Module):
)
# Transformer blocks
self
.
h
=
nn
.
ModuleList
([
FalconDecoderLayer
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
FalconDecoderLayer
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.h"
)
# Final Layer Norm
self
.
ln_f
=
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
def
forward
(
self
,
input_ids
:
torch
.
Long
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
word_embeddings
(
input_ids
)
for
i
in
range
(
len
(
self
.
h
)):
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
word_embeddings
(
input_ids
)
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
class
FalconForCausalLM
(
nn
.
Module
):
class
FalconForCausalLM
(
nn
.
Module
,
SupportsPP
):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping
=
{}
default_bitsandbytes_target_modules
=
[
".query_key_value."
,
".dense."
,
".dense_h_to_4h."
,
".dense_4h_to_h."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".dense_4h_to_h."
,
".dense."
]
def
__init__
(
self
,
...
...
@@ -403,6 +428,8 @@ class FalconForCausalLM(nn.Module):
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -412,12 +439,8 @@ class FalconForCausalLM(nn.Module):
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -454,6 +477,8 @@ class FalconForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
if
"query_key_value"
in
name
:
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
...
...
vllm/model_executor/models/fuyu.py
View file @
6d2051cc
...
...
@@ -27,11 +27,11 @@ from transformers import FuyuConfig, FuyuImageProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.persimmon
import
PersimmonForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -41,8 +41,8 @@ from vllm.multimodal.utils import cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
.interfaces
import
SupportsMultiModal
from
.utils
import
merge_multimodal_embeddings
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
AutoWeightsLoader
,
flatten_bn
,
merge_multimodal_embeddings
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID
=
71011
...
...
@@ -150,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
return
model_image_input
def
input_processor_for_fuyu
(
ctx
:
InputContext
,
llm_
inputs
:
LLM
Inputs
):
multi_modal_data
=
llm_
inputs
.
get
(
"multi_modal_data"
)
def
input_processor_for_fuyu
(
ctx
:
InputContext
,
inputs
:
DecoderOnly
Inputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_
inputs
return
inputs
model_config
=
ctx
.
model_config
image_data
=
multi_modal_data
[
"image"
]
...
...
@@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
model_config
.
model
)
model_image_input
=
_fuyu_image_preprocess
(
image_processor
,
image_data
)
image_patches
=
torch
.
stack
([
image_patches
=
torch
.
cat
([
image_patch
[
0
]
for
image_patch
in
model_image_input
[
"image_patches"
]
])
...
...
@@ -177,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
# process prompts
prompt
=
llm_
inputs
.
get
(
"prompt"
)
prompt_token_ids
=
llm_
inputs
[
"prompt_token_ids"
]
prompt
=
inputs
.
get
(
"prompt"
)
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
tokenizer
=
cached_get_tokenizer
(
model_config
.
model
)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids
:
List
[
List
[
...
...
@@ -191,9 +191,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt_token_ids
=
image_input_ids
+
bos_token
+
prompt_token_ids
[
1
:]
+
boa_token
return
LLMI
nputs
(
prompt
=
new_prompt
,
prompt_token_ids
=
new_prompt_token_ids
,
multi_modal_data
=
new_multi_modal_data
)
return
token_i
nputs
(
prompt
=
new_prompt
,
prompt_token_ids
=
new_prompt_token_ids
,
multi_modal_data
=
new_multi_modal_data
)
def
input_mapper_for_fuyu
(
ctx
:
InputContext
,
data
:
object
):
...
...
@@ -210,14 +210,14 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
])
# image has been processed with prompt in input processor
return
MultiModalInputs
({
"
image_patch
es"
:
data
})
return
MultiModalInputs
({
"
pixel_valu
es"
:
data
})
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_fuyu
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_fuyu_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_fuyu
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_fuyu
)
class
FuyuForCausalLM
(
nn
.
Module
,
SupportsMultiModal
):
class
FuyuForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
config
:
FuyuConfig
,
...
...
@@ -237,28 +237,54 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self
.
image_feature_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
gather_output
=
True
,
)
self
.
language_model
=
PersimmonForCausalLM
(
config
,
self
.
language_model
=
PersimmonForCausalLM
(
config
.
text_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
@
property
def
sampler
(
self
):
return
self
.
language_model
.
sampler
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
patch_size
num_channels
=
self
.
config
.
num_channels
expected_dims
=
num_channels
*
h
*
w
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
d
.
size
(
-
1
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
"The expected shape of pixel values per image per batch "
f
" per patch is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
.
to
(
self
.
vision_embed_tokens
.
weight
.
dtype
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePixelInputs
]:
image_patch
es
=
kwargs
.
pop
(
"
image_patch
es"
,
None
)
pixel_valu
es
=
kwargs
.
pop
(
"
pixel_valu
es"
,
None
)
if
isinstance
(
image_patches
,
torch
.
Tensor
):
# Remove the N dimension until multiple images are supported.
image_patches
=
image_patches
.
squeeze
(
1
)
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image patches. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
FuyuImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
,
concat
=
True
)),
)
expected_feature_size
=
self
.
image_feature_size
if
image_patches
.
size
(
-
1
)
!=
expected_feature_size
:
raise
ValueError
(
f
"Expected image patches to have the last dimension of "
f
"
{
expected_feature_size
}
, got
{
image_patches
.
size
(
-
1
)
}
"
)
image_patches
=
image_patches
.
to
(
self
.
vision_embed_tokens
.
weight
.
dtype
)
return
FuyuImagePixelInputs
(
type
=
"pixel_values"
,
data
=
image_patches
)
return
None
def
_process_image_input
(
...
...
@@ -277,23 +303,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
):
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
intermediate_tensors
is
not
None
:
input_ids
=
None
inputs_embeds
=
None
else
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
embed_tokens
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
image_token_id
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
embed_tokens
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
image_token_id
)
else
:
inputs_embeds
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
...
...
@@ -316,34 +348,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
param
=
params_dict
[
name
]
if
"query_key_value"
in
name
:
# copy from vllm/model_executor/models/bloom.py
# NOTE: Fuyu's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
num_heads
=
self
.
config
.
num_attention_heads
if
output_dim
is
not
None
:
loaded_weight_shape
=
loaded_weight
.
shape
loaded_weight
=
loaded_weight
.
view
(
loaded_weight_shape
[:
output_dim
]
+
(
num_heads
,
3
,
-
1
)
+
loaded_weight_shape
[
output_dim
+
1
:])
loaded_weight
=
loaded_weight
.
transpose
(
output_dim
,
output_dim
+
1
)
loaded_weight
=
loaded_weight
.
reshape
(
loaded_weight_shape
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
vllm/model_executor/models/gemma.py
View file @
6d2051cc
...
...
@@ -15,7 +15,7 @@
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from
functools
import
lru_cache
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -23,7 +23,7 @@ from transformers import GemmaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
...
...
@@ -31,8 +31,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
logger
=
init_logger
(
__name__
)
...
...
@@ -245,6 +246,7 @@ class GemmaModel(nn.Module):
config
:
GemmaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -253,10 +255,11 @@ class GemmaModel(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
GemmaDecoderLayer
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
GemmaDecoderLayer
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
# Normalize the embedding by sqrt(hidden_size)
...
...
@@ -265,6 +268,9 @@ class GemmaModel(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402
normalizer
=
self
.
config
.
hidden_size
**
0.5
self
.
register_buffer
(
"normalizer"
,
torch
.
tensor
(
normalizer
))
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
...
...
@@ -275,29 +281,38 @@ class GemmaModel(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
hidden_states
*=
self
.
normalizer
residual
=
None
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
hidden_states
*=
self
.
normalizer
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -317,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
"gate_up_proj"
,
"down_proj"
,
]
# bitsandbytes-specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
# Gemma does not apply LoRA to the embedding layer.
embedding_modules
=
{}
embedding_padding_modules
=
[]
...
...
@@ -339,6 +376,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
self
.
model
=
GemmaModel
(
config
,
cache_config
,
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -347,9 +386,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -388,6 +427,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -400,6 +441,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/gemma2.py
View file @
6d2051cc
...
...
@@ -14,15 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
Gemma2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
...
...
@@ -30,17 +31,20 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.
quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.
pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
logger
=
init_logger
(
__name__
)
...
...
@@ -237,6 +241,13 @@ class Gemma2DecoderLayer(nn.Module):
return
hidden_states
,
residual
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
class
Gemma2Model
(
nn
.
Module
):
def
__init__
(
...
...
@@ -244,6 +255,7 @@ class Gemma2Model(nn.Module):
config
:
Gemma2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -252,10 +264,11 @@ class Gemma2Model(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
Gemma2DecoderLayer
(
layer_idx
,
config
,
cache_config
,
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Gemma2DecoderLayer
(
int
(
prefix
.
split
(
"."
)[
-
1
]),
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
# Normalize the embedding by sqrt(hidden_size)
...
...
@@ -264,32 +277,92 @@ class Gemma2Model(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402
normalizer
=
self
.
config
.
hidden_size
**
0.5
self
.
register_buffer
(
"normalizer"
,
torch
.
tensor
(
normalizer
))
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
Optional
[
torch
.
Tensor
]
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
*=
self
.
normalizer
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
*=
self
.
normalizer
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
class
Gemma2ForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
unloaded_params
=
params_dict
.
keys
()
-
loaded_params
if
unloaded_params
:
logger
.
warning
(
"Some weights are not initialized from checkpoints: %s"
,
unloaded_params
)
class
Gemma2ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -312,6 +385,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Gemma does not apply LoRA to the embedding layer.
embedding_modules
=
{}
embedding_padding_modules
=
[]
# bitsandbytes-specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
@@ -338,6 +424,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
soft_cap
=
config
.
final_logit_softcapping
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -346,9 +434,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -369,44 +457,56 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if
"lm_head.weight"
in
name
:
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
loader
.
load_weights
(
weights
)
unloaded_params
=
params_dict
.
keys
()
-
loaded_params
if
unloaded_params
:
logger
.
warning
(
"Some weights are not initialized from checkpoints: %s"
,
unloaded_params
)
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
    """
    A model that uses Gemma2 with additional embedding functionalities.

    This class encapsulates the Gemma2Model and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of Gemma2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations
            (constructed with PoolingType.LAST and normalize=True).
        make_empty_intermediate_tensors: Re-exported from the wrapped model.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Build the wrapped Gemma2Model, forwarding all keyword arguments."""
        super().__init__()
        # Every keyword argument is passed straight through to Gemma2Model.
        self.model = Gemma2Model(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        # Expose the wrapped model's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped Gemma2Model and return its result unchanged.

        All arguments are handed positionally to ``self.model`` in the order
        they are declared here; no pooling happens in this method.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the configured Pooler to ``hidden_states``."""
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Delegate weight loading to the wrapped Gemma2Model.

        NOTE(review): the weights iterable is passed through untouched, so
        names are presumably expected without a ``model.`` prefix -- confirm
        against the checkpoint layout.
        """
        self.model.load_weights(weights)
vllm/model_executor/models/glm4_vision_encoder.py
0 → 100644
View file @
6d2051cc
# coding=utf-8
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
from
argparse
import
Namespace
from
typing
import
Optional
import
torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
class PatchEmbedding(nn.Module):
    """Convert an image batch into patch tokens preceded by a class token.

    A strided ``Conv2d`` (kernel and stride both ``config.patch_size``) maps
    each patch to a ``config.hidden_size`` vector. A learned class embedding
    is prepended and the full position-embedding table is added.
    """

    def __init__(self, config):
        super().__init__()
        self.proj = nn.Conv2d(config.in_channels,
                              config.hidden_size,
                              kernel_size=config.patch_size,
                              stride=config.patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.position_embedding = nn.Embedding(config.num_positions,
                                               config.hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        # Match both the device AND the dtype of the projection weights.
        # Moving only the device makes Conv2d fail with a dtype mismatch
        # when the model runs in half/bfloat16 and the pixels are float32.
        images = images.to(device=self.proj.weight.device,
                           dtype=self.proj.weight.dtype)
        x = self.proj(images)
        # (B, D, H', W') -> (B, H'*W', D)
        x = x.flatten(2).transpose(1, 2)
        cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        # The whole table is added, so the sequence length (patches + 1)
        # must equal config.num_positions.
        x += self.position_embedding.weight.unsqueeze(0)
        return x
class Attention(nn.Module):
    """Non-causal multi-head self-attention for the vision encoder.

    The per-rank head count is ``config.num_heads`` divided by the
    tensor-parallel world size.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_rank = config.num_heads // self.tp_size
        self.head_dim = width // config.num_heads
        # Stored for reference; forward() relies on the default scaling of
        # scaled_dot_product_attention and does not pass this value.
        self.scale = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            width,
            self.head_dim,
            config.num_heads,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            width,
            width,
            quant_config=quant_config,
        )
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over a ``(B, L, hidden)`` input and project the result."""
        batch, seq_len, _ = x.shape
        fused, _ = self.query_key_value(x)  # B, L, 3 * H * D

        # Split the fused projection into three equal chunks, then lay each
        # one out as (B, H, L, D) for the attention kernel.
        per_head = (batch, seq_len, self.num_heads_per_rank, self.head_dim)
        query, key, value = (
            part.reshape(per_head).permute(0, 2, 1, 3)
            for part in fused.chunk(3, dim=-1))

        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.,
            is_causal=False)

        # reshape (not view): the transposed tensor is not contiguous.
        merged = context.transpose(1, 2).reshape(batch, seq_len, -1)
        projected, _ = self.dense(merged)
        return self.output_dropout(projected)
class MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # Activation is looked up by name from config.hidden_act.
        self.activation_fn = get_act_fn(config.hidden_act)

        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = ColumnParallelLinear(
            hidden,
            inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            inner,
            hidden,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transform to ``x``."""
        # Both linear layers return a tuple; only the first item is used.
        expanded, _ = self.fc1(x)
        activated = self.activation_fn(expanded)
        contracted, _ = self.fc2(activated)
        return contracted
class TransformerLayer(nn.Module):
    """One encoder layer: attention and MLP branches, each with a residual.

    LayerNorm is applied to the OUTPUT of each sub-layer before it is added
    back to that sub-layer's input.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width, eps = config.hidden_size, config.layer_norm_eps
        self.input_layernorm = LayerNorm(width, eps=eps)
        self.attention = Attention(config, quant_config=quant_config)
        self.mlp = MLP(config, quant_config=quant_config)
        self.post_attention_layernorm = LayerNorm(width, eps=eps)

    def forward(self, hidden_states):
        """Return ``hidden_states`` after the attention and MLP branches."""
        # Attention branch: normalise the sub-layer output, then add residual.
        attn_branch = self.input_layernorm(self.attention(hidden_states))
        after_attention = hidden_states + attn_branch

        # MLP branch follows the same output-norm + residual pattern.
        mlp_branch = self.post_attention_layernorm(self.mlp(after_attention))
        return after_attention + mlp_branch
class Transformer(nn.Module):
    """A stack of ``config.num_hidden_layers`` TransformerLayer modules."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        blocks = [
            TransformerLayer(config, quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in order."""
        for block in self.layers:
            hidden_states = block(hidden_states)
        return hidden_states
class GLU(nn.Module):
    """Gated projection head: linear -> LayerNorm -> GELU -> gated FFN."""

    def __init__(
        self,
        config,
        in_features,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """
        The original implementation is the same as:
        ```python
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config
        )

        self.gate_proj = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config
        )
        ```
        ```
        gate_proj_output, _ = self.gate_proj(x)
        dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
        x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
        ```

        We merge two ColumnParallelLinear into one MergedColumnParallelLinear:
        ```
        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config
        )
        ```
        ```
        x, _ = self.merged_proj(x)
        ```
        """
        super().__init__()
        width = config.hidden_size
        ffn_width = config.ffn_hidden_size

        self.linear_proj = ReplicatedLinear(in_features,
                                            width,
                                            bias=False,
                                            quant_config=quant_config)
        self.norm1 = nn.LayerNorm(width)
        self.act1 = nn.GELU()
        self.act2 = SiluAndMul()

        # Gate and up projections fused into one matmul; act2 consumes the
        # concatenated output.
        self.merged_proj = MergedColumnParallelLinear(
            width, [ffn_width] * 2,
            bias=False,
            quant_config=quant_config)

        self.dense_4h_to_h = RowParallelLinear(ffn_width,
                                               width,
                                               bias=False,
                                               quant_config=quant_config)

    def forward(self, x):
        """Project ``x`` and run it through the gated feed-forward path."""
        # Each linear layer returns a tuple; only the first item is used.
        projected, _ = self.linear_proj(x)
        normed = self.act1(self.norm1(projected))
        gate_and_up, _ = self.merged_proj(normed)
        gated = self.act2(gate_and_up)
        out, _ = self.dense_4h_to_h(gated)
        return out
class EVA2CLIPModel(nn.Module):
    """Vision tower: patch embedding, transformer, 2x2 conv and GLU head.

    The output sequence is wrapped in learned begin-of-image and
    end-of-image embeddings and divided by ``vision_config.scaling_factor``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # config.vision_config is unpacked as a mapping; wrap it so its
        # entries can be read as attributes.
        vision_config = Namespace(**config.vision_config)
        text_width = config.hidden_size

        self.patch_embedding = PatchEmbedding(vision_config)
        self.transformer = Transformer(vision_config,
                                       quant_config=quant_config)
        self.linear_proj = GLU(config,
                               in_features=text_width,
                               quant_config=quant_config)
        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
                              out_channels=text_width,
                              kernel_size=2,
                              stride=2)
        self.boi = nn.Parameter(torch.zeros(1, 1, text_width))
        self.eoi = nn.Parameter(torch.zeros(1, 1, text_width))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        tokens = self.transformer(self.patch_embedding(images))

        # Drop the leading token (the class token PatchEmbedding prepends).
        patches = tokens[:, 1:]
        batch, num_patches, width = patches.shape

        # Rebuild the square patch grid so the stride-2 conv can downsample.
        side = int(num_patches**0.5)
        grid = patches.view(batch, side, side, width).permute(0, 3, 1, 2)
        pooled = self.conv(grid).flatten(2).transpose(1, 2)

        projected = self.linear_proj(pooled)

        out_batch = projected.shape[0]
        begin = self.boi.expand(out_batch, -1, -1)
        end = self.eoi.expand(out_batch, -1, -1)
        wrapped = torch.cat((begin, projected, end), dim=1)
        return wrapped / self.scaling_factor
vllm/model_executor/models/gpt2.py
View file @
6d2051cc
...
...
@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
...
...
@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
GPT2Attention
(
nn
.
Module
):
...
...
@@ -204,6 +205,9 @@ class GPT2Model(nn.Module):
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.h"
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
n_embd
))
def
forward
(
self
,
...
...
@@ -234,7 +238,7 @@ class GPT2Model(nn.Module):
return
hidden_states
class
GPT2LMHeadModel
(
nn
.
Module
):
class
GPT2LMHeadModel
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
...
...
@@ -256,6 +260,8 @@ class GPT2LMHeadModel(nn.Module):
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -264,7 +270,7 @@ class GPT2LMHeadModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
hidden_states
...
...
@@ -286,16 +292,6 @@ class GPT2LMHeadModel(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
6d2051cc
...
...
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -26,14 +26,13 @@ from transformers import GPTBigCodeConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
...
...
@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
GPTBigCodeAttention
(
nn
.
Module
):
...
...
@@ -194,6 +195,7 @@ class GPTBigCodeModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -207,11 +209,15 @@ class GPTBigCodeModel(nn.Module):
self
.
embed_dim
,
org_num_embeddings
=
config
.
vocab_size
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
h
=
nn
.
ModuleList
([
GPTBigCodeBlock
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
GPTBigCodeBlock
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.h"
,
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
n_embd
))
def
forward
(
self
,
...
...
@@ -219,20 +225,28 @@ class GPTBigCodeModel(nn.Module):
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
len
(
self
.
h
)
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
class
GPTBigCodeForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
class
GPTBigCodeForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"c_attn"
:
[
"c_attn"
]}
supported_lora_modules
=
[
"c_fc"
,
"c_proj"
,
"wte"
,
"c_attn"
]
...
...
@@ -272,6 +286,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -280,9 +296,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -311,6 +327,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/gpt_j.py
View file @
6d2051cc
...
...
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -24,14 +24,13 @@ from transformers import GPTJConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
GPTJAttention
(
nn
.
Module
):
...
...
@@ -178,6 +181,7 @@ class GPTJModel(nn.Module):
config
:
GPTJConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -186,11 +190,15 @@ class GPTJModel(nn.Module):
config
.
vocab_size
,
self
.
embed_dim
,
)
self
.
h
=
nn
.
ModuleList
([
GPTJBlock
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
n_layer
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
n_layer
,
lambda
prefix
:
GPTJBlock
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.h"
,
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
n_embd
))
def
forward
(
self
,
...
...
@@ -198,21 +206,27 @@ class GPTJModel(nn.Module):
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
for
i
in
range
(
len
(
self
.
h
)):
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
wte
(
input_ids
)
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
position_ids
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
class
GPTJForCausalLM
(
nn
.
Module
):
class
GPTJForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
...
...
@@ -233,6 +247,8 @@ class GPTJForCausalLM(nn.Module):
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -241,9 +257,9 @@ class GPTJForCausalLM(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -283,6 +299,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -291,6 +309,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/gpt_neox.py
View file @
6d2051cc
...
...
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
@@ -24,14 +24,13 @@ from transformers import GPTNeoXConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
GPTNeoXAttention
(
nn
.
Module
):
...
...
@@ -191,6 +194,7 @@ class GPTNeoXModel(nn.Module):
config
:
GPTNeoXConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -199,12 +203,16 @@ class GPTNeoXModel(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
GPTNeoXLayer
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
GPTNeoXLayer
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
def
forward
(
self
,
...
...
@@ -212,21 +220,27 @@ class GPTNeoXModel(nn.Module):
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_in
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_in
(
input_ids
)
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
position_ids
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
return
hidden_states
class
GPTNeoXForCausalLM
(
nn
.
Module
):
class
GPTNeoXForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
...
...
@@ -247,6 +261,8 @@ class GPTNeoXForCausalLM(nn.Module):
self
.
embed_out
.
weight
=
self
.
gpt_neox
.
embed_in
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
gpt_neox
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
...
...
@@ -255,9 +271,9 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]
:
hidden_states
=
self
.
gpt_neox
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -288,6 +304,8 @@ class GPTNeoXForCausalLM(nn.Module):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
if
"query_key_value"
in
name
:
...
...
vllm/model_executor/models/granite.py
View file @
6d2051cc
...
...
@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
...
...
@@ -311,13 +311,13 @@ class GraniteModel(nn.Module):
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
hidden_states
*=
self
.
config
.
embedding_multiplier
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
hidden_states
*=
self
.
config
.
embedding_multiplier
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
...
...
@@ -337,7 +337,7 @@ class GraniteModel(nn.Module):
return
hidden_states
class
GraniteForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
class
GraniteForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -404,9 +404,12 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
if
hasattr
(
config
,
"logits_scaling"
):
logit_scale
/=
config
.
logits_scaling
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
scale
=
logit_scale
)
self
.
sampler
=
Sampler
()
else
:
self
.
lm_head
=
PPMissingLayer
()
...
...
@@ -428,8 +431,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata
:
SamplingMetadata
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
if
logits
is
not
None
:
logits
/=
self
.
config
.
logits_scaling
return
logits
def
sample
(
...
...
vllm/model_executor/models/granitemoe.py
0 → 100644
View file @
6d2051cc
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers.models.granitemoe
import
GraniteMoeConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.
import
mixtral
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
make_layers
class GraniteMoeMoE(nn.Module):
    """Tensor-parallel mixture-of-experts block used by GraniteMoe.

    A replicated linear router (``gate``) scores every token against every
    expert, and a ``FusedMoE`` layer evaluates the selected experts. The
    experts are built with ``reduce_results=True`` and ``renormalize=True``.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        """Build the router and the fused expert layer.

        Args:
            num_experts: Number of experts (also the router output width).
            top_k: Number of experts selected per token.
            hidden_size: Width of the incoming hidden states.
            intermediate_size: Inner width passed to the fused experts.
            params_dtype: Parameter dtype forwarded to both sub-layers.
            quant_config: Quantization config for the experts only.
            tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # The router is deliberately built without a quantization config,
        # so it always runs at half / full precision for now.
        gate_prefix = f"{prefix}.gate"
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=gate_prefix)

        experts_prefix = f"{prefix}.experts"
        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                prefix=experts_prefix)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route the tokens through the experts, keeping the input shape."""
        # The input may be 1D or 2D; remember its shape so that the result
        # can be handed back in the same layout.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)
        # routing_scores: (num_tokens, n_experts)
        routing_scores, _ = self.gate(tokens)
        mixed = self.experts(tokens, routing_scores)
        return mixed.view(input_shape)
class GraniteMoeAttention(nn.Module):
    """Self-attention layer of GraniteMoe with tensor-parallel projections.

    The layer is made of a fused QKV projection, a rotary embedding applied
    to queries and keys, the attention backend and an output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        attention_multiplier: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        """Derive the per-rank head layout and build the sub-layers.

        Args:
            hidden_size: Model hidden width.
            num_heads: Total number of query heads over all TP ranks.
            num_kv_heads: Total number of key/value heads over all TP ranks.
            max_position: Maximum position handed to the rotary embedding.
            rope_theta: Rotary base (converted with ``int`` before use).
            cache_config: Cache config forwarded to ``Attention``.
            quant_config: Quantization config of projections and attention.
            attention_multiplier: Attention scale; when ``None`` the scale
                falls back to ``head_dim ** -1``.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        # Query heads must split evenly over the tensor-parallel ranks.
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        if attention_multiplier is None:
            self.scaling = self.head_dim**-1
        else:
            self.scaling = attention_multiplier
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected
class GraniteMoeDecoderLayer(nn.Module):
    """One GraniteMoe decoder layer: attention followed by an MoE block.

    Each sub-layer is preceded by an RMSNorm and its output is multiplied by
    ``config.residual_multiplier`` before being added back to the residual.
    """

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MoE and normalization sub-layers.

        Args:
            config: Model configuration the layer sizes are read from.
            cache_config: Cache config forwarded to the attention layer.
            quant_config: Quantization config for attention and experts.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)

        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier,
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MoE sub-layers with scaled residuals."""
        # Attention sub-layer.
        skip = hidden_states
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = skip + attn_out * self.residual_multiplier

        # Mixture-of-experts sub-layer.
        skip = hidden_states
        moe_out = self.block_sparse_moe(
            self.post_attention_layernorm(hidden_states))
        hidden_states = skip + moe_out * self.residual_multiplier

        return hidden_states
class GraniteMoeModel(nn.Module):
    """Decoder stack of GraniteMoe with pipeline-parallel layer slicing.

    ``make_layers`` yields ``start_layer``/``end_layer`` and the layer
    container; ``forward`` only runs the layers in that index range.
    """

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, the decoder layers and the final norm.

        Args:
            config: Model configuration.
            cache_config: Cache config forwarded to every decoder layer.
            quant_config: Quantization config forwarded to every layer.
            lora_config: When set, the embedding vocabulary is extended by
                ``lora_extra_vocab_size * (max_loras or 1)`` entries.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; only embedded on the first PP rank.
            positions: Position ids handed to every decoder layer.
            kv_caches: Per-layer KV caches, indexed relative to
                ``start_layer``.
            attn_metadata: Attention metadata handed to every layer.
            intermediate_tensors: Output of the previous PP rank; must not
                be ``None`` on any rank but the first.

        Returns:
            The normalized hidden states on the last PP rank. On every other
            rank an ``IntermediateTensors`` holding ``"hidden_states"`` is
            returned instead (the annotation is kept for compatibility).
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            hidden_states *= self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            # Only the hidden states are consumed here; any extra entry in
            # the incoming dict (e.g. a "residual" key) is ignored.
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches only covers the layers owned by this rank.
            hidden_states = layer(positions, hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not pp_group.is_last_rank:
            # The decoder layers keep no separate residual stream, so there
            # is nothing but the hidden states to hand to the next rank.
            # Previously a "residual" entry was forwarded as well, which was
            # None on the first rank and never used downstream.
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.norm(hidden_states)
        return hidden_states
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GraniteMoe decoder with an LM head, logits processor and sampler."""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder, the LM head and the sampling helpers.

        Args:
            config: Model configuration.
            cache_config: Cache config forwarded to the decoder.
            quant_config: Quantization config for decoder and LM head.
            lora_config: Optional LoRA config; enlarges the vocabulary and
                selects the LM head padding size when set.
        """
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.model = GraniteMoeModel(config,
                                     cache_config,
                                     quant_config,
                                     lora_config=lora_config,
                                     prefix="model")

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            # We need bigger padding if using lora for kernel
            # compatibility
            vocab_padding = lora_config.lora_vocab_padding_size
        else:
            vocab_padding = DEFAULT_VOCAB_PADDING_SIZE

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=vocab_padding,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            # Share the input embedding matrix with the output head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Delegate to the decoder and return whatever it produces."""
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors)

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        """Turn hidden states into logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Return zero-filled ``hidden_states`` and ``residual`` tensors.

        Both tensors have shape ``(batch_size, config.hidden_size)``.
        """
        shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors({
            key: torch.zeros(shape, dtype=dtype, device=device)
            for key in ("hidden_states", "residual")
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample the next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Rename checkpoint weights and load them with the Mixtral loader.

        The stacked expert tensors are unrolled along their first dimension
        into per-expert ``w1``/``w3`` (both halves of ``input_linear``) and
        ``w2`` (``output_linear``) entries, the router weight is renamed to
        ``gate.weight``, and ``lm_head.weight`` is dropped when the word
        embeddings are tied. Every other entry is passed through unchanged.
        """
        input_suffix = '.block_sparse_moe.input_linear.weight'
        output_suffix = '.block_sparse_moe.output_linear.weight'
        router_suffix = '.block_sparse_moe.router.layer.weight'

        remapped = {}
        for name, tensor in weights:
            if name.endswith(input_suffix):
                for expert_id in range(tensor.size(0)):
                    w1_name = name.replace(
                        input_suffix,
                        ".block_sparse_moe.experts.%d.w1.weight" % expert_id)
                    w3_name = name.replace(
                        input_suffix,
                        ".block_sparse_moe.experts.%d.w3.weight" % expert_id)
                    # The first half of dim 0 is w1, the second half is w3.
                    w1_param, w3_param = tensor[expert_id].chunk(2, dim=0)
                    assert w1_name not in remapped
                    assert w3_name not in remapped
                    remapped[w1_name] = w1_param
                    remapped[w3_name] = w3_param
                continue

            if name.endswith(output_suffix):
                for expert_id in range(tensor.size(0)):
                    w2_name = name.replace(
                        output_suffix,
                        ".block_sparse_moe.experts.%d.w2.weight" % expert_id)
                    assert w2_name not in remapped
                    remapped[w2_name] = tensor[expert_id]
                continue

            if name.endswith(router_suffix):
                gate_name = name.replace(router_suffix,
                                         ".block_sparse_moe.gate.weight")
                assert gate_name not in remapped
                remapped[gate_name] = tensor
                continue

            if name == 'lm_head.weight' and self.config.tie_word_embeddings:
                # The head shares the embedding weight; nothing to load.
                continue

            remapped[name] = tensor

        mixtral.MixtralForCausalLM.load_weights(self, remapped.items())
vllm/model_executor/models/idefics2_vision_model.py
View file @
6d2051cc
...
...
@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_positions
,
self
.
embed_dim
)
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
patch_attention_mask
:
torch
.
BoolTensor
,
)
->
torch
.
Tensor
:
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
patch_attention_mask
:
torch
.
BoolTensor
,
tgt_sizes
:
Optional
[
torch
.
IntTensor
]
=
None
)
->
torch
.
Tensor
:
batch_size
,
_
,
max_im_h
,
max_im_w
=
pixel_values
.
shape
patch_embeds
=
self
.
patch_embedding
(
pixel_values
)
embeddings
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
...
...
@@ -84,8 +83,13 @@ class Idefics2VisionEmbeddings(nn.Module):
fill_value
=
0
)
for
batch_idx
,
p_attn_mask
in
enumerate
(
patch_attention_mask
):
nb_patches_h
=
p_attn_mask
[:,
0
].
sum
()
nb_patches_w
=
p_attn_mask
[
0
].
sum
()
if
tgt_sizes
is
not
None
:
nb_patches_h
=
tgt_sizes
[
batch_idx
][
0
]
nb_patches_w
=
tgt_sizes
[
batch_idx
][
1
]
else
:
nb_patches_h
=
p_attn_mask
[:,
0
].
sum
()
nb_patches_w
=
p_attn_mask
[
0
].
sum
()
fractional_coords_h
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_h
)
fractional_coords_w
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
nb_patches_w
)
bucket_coords_h
=
torch
.
bucketize
(
fractional_coords_h
,
...
...
@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
self
,
pixel_values
,
patch_attention_mask
:
Optional
[
torch
.
BoolTensor
]
=
None
,
)
->
torch
.
tensor
:
tgt_sizes
:
Optional
[
torch
.
IntTensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embeddings
(
pixel_values
=
pixel_values
,
patch_attention_mask
=
patch_attention_mask
)
patch_attention_mask
=
patch_attention_mask
,
tgt_sizes
=
tgt_sizes
)
encoder_outputs
=
self
.
encoder
(
hidden_states
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
return
last_hidden_state
vllm/model_executor/models/interfaces.py
View file @
6d2051cc
from
typing
import
(
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
import
torch
from
typing_extensions
import
TypeIs
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
supports_kw
if
TYPE_CHECKING
:
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.sequence
import
IntermediateTensors
logger
=
init_logger
(
__name__
)
...
...
@@ -22,7 +27,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""
def
__init__
(
self
,
*
,
multimodal_config
:
MultiModalConfig
)
->
None
:
def
__init__
(
self
,
*
,
multimodal_config
:
"
MultiModalConfig
"
)
->
None
:
...
...
...
@@ -32,7 +37,7 @@ class SupportsMultiModal(Protocol):
class
_SupportsMultiModalType
(
Protocol
):
supports_multimodal
:
Literal
[
True
]
def
__call__
(
self
,
*
,
multimodal_config
:
MultiModalConfig
)
->
None
:
def
__call__
(
self
,
*
,
multimodal_config
:
"
MultiModalConfig
"
)
->
None
:
...
...
...
@@ -75,7 +80,7 @@ class SupportsLoRA(Protocol):
embedding_padding_modules
:
ClassVar
[
List
[
str
]]
# lora_config is None when LoRA is not enabled
def
__init__
(
self
,
*
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
)
->
None
:
def
__init__
(
self
,
*
,
lora_config
:
Optional
[
"
LoRAConfig
"
]
=
None
)
->
None
:
...
...
...
@@ -90,7 +95,7 @@ class _SupportsLoRAType(Protocol):
embedding_modules
:
Dict
[
str
,
str
]
embedding_padding_modules
:
List
[
str
]
def
__call__
(
self
,
*
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
)
->
None
:
def
__call__
(
self
,
*
,
lora_config
:
Optional
[
"
LoRAConfig
"
]
=
None
)
->
None
:
...
...
...
@@ -136,15 +141,128 @@ def supports_lora(
return
result
def
_supports_lora
(
model
:
Union
[
Type
[
object
],
object
],
)
->
Union
[
TypeIs
[
Type
[
SupportsLoRA
]],
TypeIs
[
SupportsLoRA
]]:
def _supports_lora(model: Union[Type[object], object]) -> bool:
    """Check ``model`` against the LoRA protocol matching its kind.

    A class object is checked against ``_SupportsLoRAType``; anything else
    is checked against ``SupportsLoRA``.
    """
    protocol = _SupportsLoRAType if isinstance(model, type) else SupportsLoRA
    return isinstance(model, protocol)
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` on every PP rank except the last
        one; the last PP rank returns the final :class:`torch.Tensor`.
        """
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@
runtime_checkable
class
_SupportsPPType
(
Protocol
):
supports_pp
:
Literal
[
True
]
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
)
->
"IntermediateTensors"
:
...
def
forward
(
self
,
*
,
intermediate_tensors
:
Optional
[
"IntermediateTensors"
],
)
->
Union
[
torch
.
Tensor
,
"IntermediateTensors"
]:
...
@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Tell whether ``model`` (a class or an instance) supports PP.

    The model qualifies only if it satisfies the PP protocol *and* its
    ``forward`` accepts ``intermediate_tensors``. Inconsistent declarations
    are reported through ``logger.warning`` but never raise.
    """
    has_attributes = _supports_pp_attributes(model)
    accepts_intermediate = _supports_pp_inspect(model)

    if has_attributes:
        if not accepts_intermediate:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, but does not accept "
                "`intermediate_tensors` in its `forward` method", model)
    else:
        # The protocol check failed: find out whether the flag and the
        # PP-specific attributes disagree with each other.
        required_attrs = ("make_empty_intermediate_tensors", )
        absent = tuple(name for name in required_attrs
                       if not hasattr(model, name))
        declares_pp = getattr(model, "supports_pp", False)

        if declares_pp and absent:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                absent,
            )
        elif not declares_pp and not absent:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_attributes and accepts_intermediate
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    """Check ``model`` against the pipeline-parallel protocol.

    A class object is matched against ``_SupportsPPType`` (the class is
    treated as an instance); anything else is matched against ``SupportsPP``.
    """
    protocol = _SupportsPPType if isinstance(model, type) else SupportsPP
    return isinstance(model, protocol)
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Check that ``model.forward`` exists, is callable and accepts the
    ``intermediate_tensors`` keyword (as reported by ``supports_kw``)."""
    forward_fn = getattr(model, "forward", None)
    return callable(forward_fn) and supports_kw(forward_fn,
                                                "intermediate_tensors")
@
runtime_checkable
class
HasInnerState
(
Protocol
):
"""The interface required for all models that has inner state."""
...
...
@@ -153,12 +271,12 @@ class HasInnerState(Protocol):
"""
A flag that indicates this model has inner state.
Models that has inner state usually need access to the scheduler_config
for max_num_seqs
,etc.
.. (Currently only used by
Jamba
)
for max_num_seqs,
etc.
True for e.g. both Mamba and
Jamba
.
"""
def
__init__
(
self
,
*
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
)
->
None
:
scheduler_config
:
Optional
[
"
SchedulerConfig
"
]
=
None
)
->
None
:
...
...
...
@@ -168,7 +286,7 @@ class _HasInnerStateType(Protocol):
def
__init__
(
self
,
*
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
)
->
None
:
scheduler_config
:
Optional
[
"
SchedulerConfig
"
]
=
None
)
->
None
:
...
...
...
@@ -189,3 +307,46 @@ def has_inner_state(
return
isinstance
(
model
,
_HasInnerStateType
)
return
isinstance
(
model
,
HasInnerState
)
@runtime_checkable
class IsAttentionFree(Protocol):
    """Interface for models without attention (Mamba-like) whose state has a
    size that is constant wrt the number of tokens."""

    is_attention_free: ClassVar[Literal[True]] = True
    """
    Marker flag stating that this model has no attention.

    Used for block manager and attention backend selection.
    True for Mamba but not Jamba.
    """

    def __init__(self) -> None:
        ...
@
runtime_checkable
class
_IsAttentionFreeType
(
Protocol
):
is_attention_free
:
ClassVar
[
Literal
[
True
]]
def
__init__
(
self
)
->
None
:
...
# NOTE: the `Type[object]` overload must come first. Overloads are matched in
# declaration order and `object` accepts every argument, so listing it first
# would make the class overload unreachable for type checkers.
@overload
def is_attention_free(
        model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Tell whether ``model`` (a class or an instance) is attention-free.

    Args:
        model: A model class or a model instance.

    Returns:
        ``True`` if the model satisfies the ``IsAttentionFree`` protocol.
    """
    if isinstance(model, type):
        # A class object is checked as if it were an instance, against the
        # class-level mirror protocol.
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)
vllm/model_executor/models/interfaces_base.py
0 → 100644
View file @
6d2051cc
from
typing
import
(
TYPE_CHECKING
,
List
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
typing_extensions
import
TypeIs
,
TypeVar
from
vllm.logger
import
init_logger
from
vllm.utils
import
supports_kw
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.model_executor.layers.pooler
import
PoolerOutput
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
logger
=
init_logger
(
__name__
)
# The type of HF config.
# Bound to `PretrainedConfig` and declared covariant.
C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)

# The type of hidden states.
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor].
# Both variables default to `torch.Tensor`; `T_co` is the covariant variant.
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
# NOTE: In contrast to the protocols in `interfaces.py`, the base interfaces
# deliberately carry no `ClassVar` tags. Adding them would break OOT
# registration for existing models that don't inherit from the base
# interface classes.
@runtime_checkable
class VllmModel(Protocol[C_co, T_co]):
    """Structural interface of a vLLM model.

    It consists of an initializer that takes the HF config plus the
    keyword-only ``cache_config`` and ``quant_config``, and a ``forward``
    method that takes ``input_ids``, ``positions``, ``kv_caches`` and
    ``attn_metadata``.
    """

    def __init__(
        self,
        config: C_co,
        *,
        cache_config: Optional["CacheConfig"],
        quant_config: Optional["QuantizationConfig"],
    ) -> None:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
    ) -> T_co:
        ...
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
    """Check that the model initializer accepts the vLLM-specific keywords.

    The keywords are ``cache_config`` and ``quant_config``. A warning is
    logged only when ``model`` is an ``nn.Module`` subclass that lacks some
    of them.

    Returns:
        ``True`` if no keyword is missing.
    """
    init_fn = model.__init__
    absent = tuple(name for name in ("cache_config", "quant_config")
                   if not supports_kw(init_fn, name))

    if absent and isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its initializer: %s",
            model,
            absent,
        )

    return not absent
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
    """Check that ``model.forward`` accepts the vLLM-specific keywords.

    The keywords are ``input_ids``, ``positions``, ``kv_caches`` and
    ``attn_metadata``. A warning is logged only when ``model`` is an
    ``nn.Module`` subclass that lacks some of them.

    Args:
        model: A model class or a model instance.

    Returns:
        ``False`` if there is no callable ``forward`` attribute or if any
        keyword is missing, ``True`` otherwise.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))

    if missing_kws and (isinstance(model, type)
                        and issubclass(model, nn.Module)):
        # The message names the `forward` method: this check inspects
        # `forward`, not the initializer.
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            missing_kws,
        )

    return len(missing_kws) == 0
@overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
    """Tell whether ``model`` (a class or an instance) looks like a vLLM model.

    The initializer is checked first; ``forward`` is only inspected when the
    initializer check succeeds.
    """
    if not _check_vllm_model_init(model):
        return False
    return _check_vllm_model_forward(model)
@runtime_checkable
class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
    """``VllmModel`` extended with ``compute_logits`` and ``sample``."""

    def compute_logits(
        self,
        hidden_states: T,
        sampling_metadata: "SamplingMetadata",
    ) -> Optional[T]:
        """Return `None` if TP rank > 0."""
        ...

    def sample(
        self,
        logits: T,
        sampling_metadata: "SamplingMetadata",
    ) -> "SamplerOutput":
        """Only called on TP rank 0."""
        ...
@overload
def is_text_generation_model(
        model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...


def is_text_generation_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Tell whether ``model`` (a class or an instance) is a vLLM model that
    also satisfies ``VllmModelForTextGeneration``."""
    # Classes and instances go through the very same isinstance() check:
    # a class object is simply treated as an instance of the protocol.
    return (is_vllm_model(model)
            and isinstance(model, VllmModelForTextGeneration))
@runtime_checkable
class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
    """``VllmModel`` extended with a ``pooler`` method."""

    def pooler(
        self,
        hidden_states: T,
        pooling_metadata: "PoolingMetadata",
    ) -> "PoolerOutput":
        """Only called on TP rank 0."""
        ...
@overload
def is_embedding_model(
        model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
    ...


@overload
def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
    ...


def is_embedding_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForEmbedding]],
           TypeIs[VllmModelForEmbedding]]:
    """Tell whether ``model`` (a class or an instance) is a vLLM model that
    also satisfies ``VllmModelForEmbedding``."""
    # Classes and instances go through the very same isinstance() check:
    # a class object is simply treated as an instance of the protocol.
    return is_vllm_model(model) and isinstance(model, VllmModelForEmbedding)
vllm/model_executor/models/intern_vit.py
View file @
6d2051cc
...
...
@@ -4,6 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
functools
import
partial
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
...
...
@@ -11,7 +12,10 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -54,7 +58,7 @@ class InternVisionEmbeddings(nn.Module):
self
.
position_embedding
=
nn
.
Parameter
(
torch
.
randn
(
1
,
self
.
num_positions
,
self
.
embed_dim
))
def
_get_pos_embed
(
self
,
pos_embed
,
H
,
W
):
def
_get_pos_embed
(
self
,
pos_embed
:
torch
.
Tensor
,
H
:
int
,
W
:
int
):
target_dtype
=
pos_embed
.
dtype
pos_embed
=
pos_embed
.
float
().
reshape
(
1
,
self
.
image_size
//
self
.
patch_size
,
...
...
@@ -63,9 +67,21 @@ class InternVisionEmbeddings(nn.Module):
size
=
(
H
,
W
),
mode
=
'bicubic'
,
align_corners
=
False
)
pos_embed
=
pos_embed
.
reshape
(
1
,
-
1
,
H
*
W
).
permute
(
0
,
2
,
1
).
to
(
target_dtype
)
return
pos_embed
return
pos_embed
.
reshape
(
1
,
-
1
,
H
*
W
).
permute
(
0
,
2
,
1
).
to
(
target_dtype
)
def
_get_position_embedding
(
self
,
H
:
int
,
W
:
int
)
->
torch
.
Tensor
:
position_embedding
=
self
.
position_embedding
if
self
.
num_patches
==
H
*
W
:
return
position_embedding
return
torch
.
cat
(
[
position_embedding
[:,
:
1
,
:],
self
.
_get_pos_embed
(
position_embedding
[:,
1
:,
:],
H
,
W
),
],
dim
=
1
,
)
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
)
->
torch
.
Tensor
:
target_dtype
=
self
.
patch_embedding
.
weight
.
dtype
...
...
@@ -76,12 +92,7 @@ class InternVisionEmbeddings(nn.Module):
class_embeds
=
self
.
class_embedding
.
expand
(
batch_size
,
1
,
-
1
).
to
(
target_dtype
)
embeddings
=
torch
.
cat
([
class_embeds
,
patch_embeds
],
dim
=
1
)
position_embedding
=
torch
.
cat
([
self
.
position_embedding
[:,
:
1
,
:],
self
.
_get_pos_embed
(
self
.
position_embedding
[:,
1
:,
:],
height
,
width
)
],
dim
=
1
)
position_embedding
=
self
.
_get_position_embedding
(
height
,
width
)
embeddings
=
embeddings
+
position_embedding
.
to
(
target_dtype
)
return
embeddings
...
...
@@ -93,8 +104,11 @@ class InternParallelAttention(nn.Module):
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
*
,
num_dummy_heads
:
int
=
0
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
...
...
@@ -105,11 +119,19 @@ class InternParallelAttention(nn.Module):
f
'(got `embed_dim`:
{
self
.
embed_dim
}
and `num_heads`:'
f
'
{
self
.
num_heads
}
).'
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
# Additional dummy heads are used to enable TP for common GPU counts.
self
.
dummy_dim
=
(
num_dummy_heads
+
self
.
num_heads
)
*
self
.
head_dim
self
.
num_heads_per_partition
=
divide
(
num_dummy_heads
+
self
.
num_heads
,
self
.
tp_size
)
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
qkv
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
self
.
num_heads
,
num_dummy_heads
+
self
.
num_heads
,
bias
=
config
.
qkv_bias
,
quant_config
=
quant_config
,
)
...
...
@@ -117,34 +139,44 @@ class InternParallelAttention(nn.Module):
self
.
qk_normalization
=
config
.
qk_normalization
if
self
.
qk_normalization
:
self
.
q_norm
=
RMSNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
q_norm
=
RMSNorm
(
self
.
dummy_dim
,
eps
=
config
.
layer_norm_eps
,
var_hidden_size
=
self
.
embed_dim
)
self
.
k_norm
=
RMSNorm
(
self
.
dummy_dim
,
eps
=
config
.
layer_norm_eps
,
var_hidden_size
=
self
.
embed_dim
)
self
.
proj
=
RowParallelLinear
(
self
.
embed
_dim
,
self
.
dummy
_dim
,
self
.
embed_dim
,
quant_config
=
quant_config
,
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
def
forward
(
self
,
x
):
B
,
N
,
C
=
x
.
shape
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
):
if
self
.
tp_size
>
1
:
q
=
tensor_model_parallel_all_gather
(
q
.
contiguous
())
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
())
q
=
self
.
q_norm
.
forward_native
(
q
)
k
=
self
.
k_norm
.
forward_native
(
k
)
if
self
.
tp_size
>
1
:
splitter
=
partial
(
split_tensor_along_last_dim
,
num_partitions
=
self
.
tp_size
)
q
=
splitter
(
q
)[
self
.
tp_rank
]
k
=
splitter
(
k
)[
self
.
tp_rank
]
return
q
,
k
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
B
,
N
,
_
=
x
.
shape
qkv
,
_
=
self
.
qkv
(
x
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=-
1
)
if
self
.
qk_normalization
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
=
q
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
k
=
k
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
v
=
v
.
view
(
B
,
N
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
if
self
.
qk_normalization
:
B_
,
N_
,
H_
,
D_
=
q
.
shape
q
=
self
.
q_norm
.
forward_native
(
q
.
flatten
(
-
2
,
-
1
)).
view
(
B_
,
N_
,
H_
,
D_
)
k
=
self
.
k_norm
.
forward_native
(
k
.
flatten
(
-
2
,
-
1
)).
view
(
B_
,
N_
,
H_
,
D_
)
x
=
xops
.
memory_efficient_attention_forward
(
q
,
k
,
v
,
scale
=
self
.
scale
)
x
=
x
.
view
(
B
,
N
,
-
1
)
...
...
@@ -155,8 +187,14 @@ class InternParallelAttention(nn.Module):
class
InternSdpaAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
:
PretrainedConfig
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
*
,
num_dummy_heads
:
int
=
0
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
...
...
@@ -167,20 +205,27 @@ class InternSdpaAttention(nn.Module):
f
'(got `embed_dim`:
{
self
.
embed_dim
}
and `num_heads`:'
f
'
{
self
.
num_heads
}
).'
)
# Additional dummy heads are used to enable TP for common GPU counts.
self
.
dummy_dim
=
(
num_dummy_heads
+
self
.
num_heads
)
*
self
.
head_dim
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
qkv
=
nn
.
Linear
(
self
.
embed_dim
,
3
*
self
.
embed
_dim
,
3
*
self
.
dummy
_dim
,
bias
=
config
.
qkv_bias
)
self
.
qk_normalization
=
config
.
qk_normalization
if
self
.
qk_normalization
:
self
.
q_norm
=
RMSNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
q_norm
=
RMSNorm
(
self
.
dummy_dim
,
eps
=
config
.
layer_norm_eps
,
var_hidden_size
=
self
.
embed_dim
)
self
.
k_norm
=
RMSNorm
(
self
.
dummy_dim
,
eps
=
config
.
layer_norm_eps
,
var_hidden_size
=
self
.
embed_dim
)
self
.
proj
=
nn
.
Linear
(
self
.
embed
_dim
,
self
.
embed_dim
)
self
.
proj
=
nn
.
Linear
(
self
.
dummy
_dim
,
self
.
embed_dim
)
def
forward
(
self
,
x
)
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
B
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
)
q
,
k
,
v
=
qkv
.
chunk
(
3
,
dim
=-
1
)
...
...
@@ -233,22 +278,23 @@ class InternMLP(nn.Module):
class
InternVisionEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
*
,
num_dummy_heads
:
int
=
0
,
)
->
None
:
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
self
.
norm_type
=
config
.
norm_type
# fallback to sdpa attention if tp unavailable
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
self
.
attn
=
InternParallelAttention
(
config
,
quant_config
=
quant_config
)
else
:
self
.
attn
=
InternSdpaAttention
(
config
)
self
.
attn
=
self
.
_init_attn
(
config
,
quant_config
,
num_dummy_heads
=
num_dummy_heads
)
self
.
mlp
=
InternMLP
(
config
,
quant_config
=
quant_config
)
self
.
norm1
=
NORM2FN
[
self
.
norm_type
](
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
...
...
@@ -260,6 +306,24 @@ class InternVisionEncoderLayer(nn.Module):
self
.
ls2
=
nn
.
Parameter
(
config
.
initializer_factor
*
torch
.
ones
(
self
.
embed_dim
))
def _init_attn(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    num_dummy_heads: int,
):
    """Build the attention module for this encoder layer.

    ``InternParallelAttention`` is used when ``USE_XFORMERS_OPS`` is set and
    the total head count (real + dummy) is divisible by the tensor-parallel
    world size; otherwise ``InternSdpaAttention`` is returned.
    """
    tp_size = get_tensor_model_parallel_world_size()
    total_heads = config.num_attention_heads + num_dummy_heads

    # fallback to sdpa attention if tp unavailable
    if not USE_XFORMERS_OPS or total_heads % tp_size != 0:
        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)

    return InternParallelAttention(config,
                                   quant_config=quant_config,
                                   num_dummy_heads=num_dummy_heads)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -275,19 +339,27 @@ class InternVisionEncoderLayer(nn.Module):
class
InternVisionEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
*
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
num_dummy_heads
:
int
=
0
,
):
super
().
__init__
()
self
.
config
=
config
if
num_hidden_layers_override
is
None
:
num_hidden_layers
=
config
.
num_hidden_layers
else
:
num_hidden_layers
=
num_hidden_layers_override
self
.
layers
=
nn
.
ModuleList
([
InternVisionEncoderLayer
(
config
=
config
,
quant_config
=
quant_config
)
InternVisionEncoderLayer
(
config
,
quant_config
,
num_dummy_heads
=
num_dummy_heads
)
for
_
in
range
(
num_hidden_layers
)
])
...
...
@@ -302,35 +374,25 @@ class InternVisionEncoder(nn.Module):
class
InternVisionModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
*
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
num_dummy_heads
:
int
=
0
,
):
super
().
__init__
()
self
.
config
=
config
self
.
embeddings
=
InternVisionEmbeddings
(
config
)
self
.
encoder
=
InternVisionEncoder
(
config
=
config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
)
def
resize_pos_embeddings
(
self
,
old_size
,
new_size
,
patch_size
):
pos_emb
=
self
.
embeddings
.
position_embedding
_
,
num_positions
,
embed_dim
=
pos_emb
.
shape
cls_emb
=
pos_emb
[:,
:
1
,
:]
pos_emb
=
pos_emb
[:,
1
:,
:].
reshape
(
1
,
old_size
//
patch_size
,
old_size
//
patch_size
,
-
1
).
permute
(
0
,
3
,
1
,
2
)
pos_emb
=
F
.
interpolate
(
pos_emb
.
float
(),
size
=
new_size
//
patch_size
,
mode
=
'bicubic'
,
align_corners
=
False
)
pos_emb
=
pos_emb
.
to
(
cls_emb
.
dtype
).
reshape
(
1
,
embed_dim
,
-
1
).
permute
(
0
,
2
,
1
)
pos_emb
=
torch
.
cat
([
cls_emb
,
pos_emb
],
dim
=
1
)
self
.
embeddings
.
position_embedding
=
nn
.
Parameter
(
pos_emb
)
self
.
embeddings
.
image_size
=
new_size
num_hidden_layers_override
=
num_hidden_layers_override
,
num_dummy_heads
=
num_dummy_heads
,
)
def
get_input_embeddings
(
self
):
return
self
.
embeddings
...
...
vllm/model_executor/models/internlm2.py
View file @
6d2051cc
...
...
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -28,6 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
...
...
@@ -266,7 +266,7 @@ class InternLM2Model(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
IntermediateTensors
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
...
...
@@ -297,7 +297,7 @@ class InternLM2Model(nn.Module):
return
hidden_states
class
InternLM2ForCausalLM
(
nn
.
Module
):
class
InternLM2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
...
...
@@ -325,7 +325,7 @@ class InternLM2ForCausalLM(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
IntermediateTensors
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
...
...
vllm/model_executor/models/internvl.py
View file @
6d2051cc
...
...
@@ -5,6 +5,7 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import
re
from
functools
import
cached_property
,
partial
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -16,11 +17,10 @@ from transformers import PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.
distributed
import
get_pp_group
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMI
nputs
from
vllm.
inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
,
token_i
nputs
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.intern_vit
import
InternVisionModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -31,9 +31,9 @@ from vllm.utils import is_list_of
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_num_patches
)
from
.interfaces
import
SupportsMultiModal
from
.utils
import
(
flatten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
IMG_START
=
'<img>'
IMG_END
=
'</img>'
...
...
@@ -122,6 +122,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
return
blocks
,
target_width
,
target_height
def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
                                 max_dynamic_patch: Optional[int] = None):
    """Return ``calculate_num_blocks`` pre-bound to values from ``hf_config``.

    ``max_dynamic_patch`` overrides ``hf_config.max_dynamic_patch`` when it
    is not ``None``. The returned partial has ``min_num``, ``max_num``,
    ``image_size`` and ``use_thumbnail`` already bound.
    """
    max_num = (hf_config.max_dynamic_patch
               if max_dynamic_patch is None else max_dynamic_patch)
    return partial(
        calculate_num_blocks,
        min_num=hf_config.min_dynamic_patch,
        max_num=max_num,
        image_size=hf_config.vision_config.image_size,
        use_thumbnail=hf_config.use_thumbnail,
    )
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
dynamic_preprocess
(
image
:
Image
.
Image
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
,
...
...
@@ -168,172 +182,231 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
return
pixel_values
def
get_internvl_num_patches
(
image_size
:
int
,
patch_size
:
int
,
downsample_ratio
:
float
):
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None):
    """Return ``image_to_pixel_values`` pre-bound to values from ``hf_config``.

    ``max_dynamic_patch`` overrides ``hf_config.max_dynamic_patch`` when it
    is not ``None``. The returned partial has ``input_size``, ``min_num``,
    ``max_num`` and ``use_thumbnail`` already bound.
    """
    return partial(
        image_to_pixel_values,
        input_size=hf_config.vision_config.image_size,
        min_num=hf_config.min_dynamic_patch,
        max_num=(hf_config.max_dynamic_patch
                 if max_dynamic_patch is None else max_dynamic_patch),
        use_thumbnail=hf_config.use_thumbnail,
    )
def get_internvl_num_patches(hf_config: PretrainedConfig):
    """Return the CLIP patch count for the vision config's image and patch
    size, scaled by ``hf_config.downsample_ratio`` squared and truncated to
    ``int``."""
    vision_config = hf_config.vision_config
    ratio = hf_config.downsample_ratio
    clip_patches = get_clip_num_patches(image_size=vision_config.image_size,
                                        patch_size=vision_config.patch_size)
    return int(clip_patches * (ratio**2))
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
):
def
get_max_internvl_image_tokens
(
ctx
:
InputContext
,
*
,
max_dynamic_patch
:
Optional
[
int
]
=
None
):
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
if
max_dynamic_patch
is
None
:
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
use_thumbnail
=
hf_config
.
use_thumbnail
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
if
use_thumbnail
:
if
use_thumbnail
and
max_dynamic_patch
>
1
:
max_dynamic_patch
+=
1
downsample_ratio
=
hf_config
.
downsample_ratio
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
num_patches
=
get_internvl_num_patches
(
hf_config
)
return
num_patches
*
max_dynamic_patch
def
input_processor_for_internvl
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_inputs
model_config
=
ctx
.
model_config
def
get_max_internvl_image_size
(
ctx
:
InputContext
,
*
,
max_dynamic_patch
:
Optional
[
int
]
=
None
):
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
image_size
=
vision_config
.
image_size
patch_size
=
vision_config
.
patch_size
downsample_ratio
=
hf_config
.
downsample_ratio
num_patches
=
get_internvl_num_patches
(
image_size
,
patch_size
,
downsample_ratio
)
image_size
=
hf_config
.
vision_config
.
image_size
image_data
=
multi_modal_data
[
"image"
]
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
if
max_dynamic_patch
is
None
:
max_dynamic_patch
=
hf_config
.
max_dynamic_patch
use_thumbnail
=
hf_config
.
use_thumbnail
if
isinstance
(
image_data
,
Image
.
Image
):
width
,
height
=
image_data
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
max_num
,
image_size
,
use_thumbnail
)
image_feature_size
=
[
num_blocks
*
num_patches
]
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
[]
for
image
in
image_data
:
width
,
height
=
image
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
max_num
,
image_size
,
use_thumbnail
)
image_feature_size
.
append
(
num_blocks
*
num_patches
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
prompt
=
llm_inputs
.
get
(
"prompt"
)
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
new_prompt
=
prompt
image_idx
=
sorted
(
map
(
int
,
re
.
findall
(
r
"Image-(\d+): <image>\n"
,
prompt
)))
for
idx
,
feature_size
in
enumerate
(
image_feature_size
,
start
=
1
):
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
feature_size
+
IMG_END
if
not
image_idx
:
image_prompt
=
f
"Image-
{
idx
}
:
{
image_prompt
}
"
new_prompt
=
new_prompt
.
replace
(
'<image>'
,
image_prompt
,
1
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
return
LLMInputs
(
prompt
=
prompt
,
prompt_token_ids
=
new_prompt_token_ids
,
multi_modal_data
=
multi_modal_data
)
def
input_mapper_for_internvl
(
ctx
:
InputContext
,
data
:
object
):
hf_config
=
ctx
.
get_hf_config
()
if
use_thumbnail
and
max_dynamic_patch
>
1
:
max_dynamic_patch
+=
1
width
=
image_size
*
max_dynamic_patch
height
=
image_size
return
width
,
height
use_thumbnail
=
hf_config
.
use_thumbnail
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
image_size
=
hf_config
.
vision_config
.
image_size
if
isinstance
(
data
,
Image
.
Image
):
data
=
image_to_pixel_values
(
data
,
image_size
,
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
# Add an N dimension for number of images per prompt (currently 1).
data
=
data
.
unsqueeze
(
0
)
elif
is_list_of
(
data
,
Image
.
Image
):
# we can't stack here because the images may have different num_patches
data
=
[
image_to_pixel_values
(
img
,
image_size
,
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
for
img
in
data
]
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
add_special_tokens
=
False
,
return_tensors
=
"pt"
)[
0
]
return
MultiModalInputs
({
"pixel_values"
:
data
,
"image_token_id"
:
image_token_id
})
def
dummy_data_for_internvl
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_max_internvl_image_tokens
(
ctx
)
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
)
seq_data
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
tokenizer
.
encode
(
IMG_CONTEXT
,
add_special_tokens
=
False
)[
0
],
image_feature_size_override
=
image_feature_size
,
)
class
InternVLInputPipeline
:
image_size
=
vision_config
.
image_size
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
max_image_width
=
max_num
*
image_size
max_image_height
=
min_num
*
image_size
def
__init__
(
self
,
img_start_token
:
str
,
img_end_token
:
str
,
img_context_token
:
str
,
)
->
None
:
super
().
__init__
()
self
.
img_start_token
=
img_start_token
self
.
img_end_token
=
img_end_token
self
.
img_context_token
=
img_context_token
def
_create_image_prompt
(
self
,
feature_size
:
int
,
num_patches
:
int
)
->
str
:
return
(
self
.
img_start_token
+
self
.
img_context_token
*
feature_size
+
self
.
img_end_token
)
mm_data
=
dummy_image_for_clip
(
vision_config
,
num_images
,
image_width_override
=
max_image_width
,
image_height_override
=
max_image_height
,
)
def
_expand_image_prompt
(
self
,
prompt
:
str
,
feature_sizes
:
List
[
int
],
num_patches
:
int
,
)
->
str
:
image_idx
=
sorted
(
map
(
int
,
re
.
findall
(
r
"Image-(\d+): <image>\n"
,
prompt
)))
new_prompt
=
prompt
for
idx
,
feature_size
in
enumerate
(
feature_sizes
,
start
=
1
):
image_prompt
=
self
.
_create_image_prompt
(
feature_size
,
num_patches
)
if
not
image_idx
:
image_prompt
=
f
"Image-
{
idx
}
:
{
image_prompt
}
"
new_prompt
=
new_prompt
.
replace
(
'<image>'
,
image_prompt
,
1
)
return
new_prompt
def input_processor(
    self,
    ctx: InputContext,
    inputs: DecoderOnlyInputs,
    *,
    max_dynamic_patch: Optional[int] = None,
) -> DecoderOnlyInputs:
    """Expand image placeholders in the prompt and re-tokenize it.

    Inputs without an ``"image"`` entry under ``"multi_modal_data"`` are
    returned untouched. Otherwise one feature size is derived per image,
    the prompt's ``<image>`` placeholders are expanded accordingly, and the
    expanded prompt is encoded into new token ids.

    Args:
        ctx: Provides ``model_config`` and ``get_hf_config()``.
        inputs: Mapping with ``"prompt_token_ids"`` and optionally
            ``"prompt"`` and ``"multi_modal_data"``.
        max_dynamic_patch: Forwarded to ``calculate_num_blocks_wrapper``.

    Returns:
        ``inputs`` itself when there is no image data; otherwise the result
        of ``token_inputs`` carrying the original prompt text, the new
        token ids and the unchanged multi-modal data.

    Raises:
        TypeError: If the image data is not a PIL image, a list of PIL
            images, or a ``torch.Tensor``.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    has_image = multi_modal_data is not None and "image" in multi_modal_data
    if not has_image:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()

    image_data = multi_modal_data["image"]
    num_patches = get_internvl_num_patches(hf_config)
    num_blocks_calculator = calculate_num_blocks_wrapper(
        hf_config, max_dynamic_patch)

    def feature_size_of(image) -> int:
        # Feature size of one image = number of blocks * patches per block.
        width, height = image.size
        num_blocks, _, _ = num_blocks_calculator(width, height)
        return num_blocks * num_patches

    if isinstance(image_data, Image.Image):
        image_feature_sizes = [feature_size_of(image_data)]
    elif is_list_of(image_data, Image.Image):
        image_feature_sizes = [
            feature_size_of(image) for image in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        # The tensor is unpacked as exactly three dimensions; only the
        # middle one is used as the (single) feature size.
        _, image_feature_size, _ = image_data.shape
        image_feature_sizes = [image_feature_size]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    prompt = inputs.get("prompt")
    prompt_token_ids = inputs["prompt_token_ids"]
    if prompt is None:
        # No text prompt supplied: recover it from the token ids.
        prompt = tokenizer.decode(prompt_token_ids)

    new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                           num_patches)
    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return token_inputs(prompt=prompt,
                        prompt_token_ids=new_prompt_token_ids,
                        multi_modal_data=multi_modal_data)
def input_mapper(
    self,
    ctx: InputContext,
    data: object,
    *,
    max_dynamic_patch: Optional[int] = None,
):
    """Convert image data into the keyword inputs passed to the model.

    Args:
        ctx: Provides ``model_config`` and ``get_hf_config()``.
        data: A PIL image, a list of PIL images, or anything else (which
            is passed through as precomputed image embeddings).
        max_dynamic_patch: Forwarded to ``image_to_pixel_values_wrapper``.

    Returns:
        ``MultiModalInputs`` with ``"image_embeds"`` when ``data`` is not a
        PIL image or list of PIL images; otherwise ``MultiModalInputs``
        with ``"pixel_values"`` and ``"image_token_id"``.
    """
    hf_config = ctx.get_hf_config()
    image_pixel_values_mapper = image_to_pixel_values_wrapper(
        hf_config, max_dynamic_patch)

    is_single_image = isinstance(data, Image.Image)
    if not is_single_image and not is_list_of(data, Image.Image):
        # Not raw images: hand the data over as-is under "image_embeds".
        return MultiModalInputs({"image_embeds": data})

    if is_single_image:
        # Add an N dimension for number of images per prompt (currently 1).
        data = image_pixel_values_mapper(data).unsqueeze(0)
    else:
        # Keep a list rather than stacking: images may have different
        # num_patches.
        data = [image_pixel_values_mapper(img) for img in data]

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    encoded_context = tokenizer.encode(self.img_context_token,
                                       add_special_tokens=False,
                                       return_tensors="pt")
    image_token_id = encoded_context[0]

    return MultiModalInputs({
        "pixel_values": data,
        "image_token_id": image_token_id
    })
def dummy_data(
    self,
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
    *,
    max_dynamic_patch: Optional[int] = None,
):
    """Create dummy sequence data and dummy images.

    Args:
        ctx: Provides ``model_config`` and ``get_hf_config()``.
        seq_len: Forwarded to ``dummy_seq_data_for_clip``.
        mm_counts: Per-modality item counts; the ``"image"`` entry is read.
        max_dynamic_patch: Forwarded to ``get_max_internvl_image_tokens``
            and ``get_max_internvl_image_size``.

    Returns:
        A ``(seq_data, mm_data)`` tuple: the results of
        ``dummy_seq_data_for_clip`` and ``dummy_image_for_clip``.

    Raises:
        KeyError: If ``mm_counts`` has no ``"image"`` entry.
    """
    num_images = mm_counts["image"]

    hf_config = ctx.get_hf_config()

    image_feature_size = get_max_internvl_image_tokens(
        ctx, max_dynamic_patch=max_dynamic_patch)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    vision_config = hf_config.vision_config
    # First id of the encoded context marker is used as the image token id.
    context_token_ids = tokenizer.encode(self.img_context_token,
                                         add_special_tokens=False)
    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=context_token_ids[0],
        image_feature_size_override=image_feature_size,
    )

    max_image_width, max_image_height = get_max_internvl_image_size(
        ctx, max_dynamic_patch=max_dynamic_patch)

    mm_data = dummy_image_for_clip(
        hf_config.vision_config,
        num_images,
        image_width_override=max_image_width,
        image_height_override=max_image_height,
    )

    return seq_data, mm_data
return
seq_data
,
mm_data
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_internvl
)
input_pipeline
=
InternVLInputPipeline
(
IMG_START
,
IMG_END
,
IMG_CONTEXT
)
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_pipeline
.
input_mapper
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_internvl_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_internvl
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_p
rocessor_for_internvl
)
class
InternVLChatModel
(
nn
.
Module
,
SupportsMultiModal
):
@
INPUT_REGISTRY
.
register_dummy_data
(
input_pipeline
.
dummy_data
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_p
ipeline
.
input_processor
)
class
InternVLChatModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
...
...
@@ -360,29 +433,40 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
+
vision_feature_layer
+
1
else
:
num_hidden_layers
=
vision_feature_layer
+
1
self
.
vision_model
=
InternVisionModel
(
config
.
vision_config
,
num_hidden_layers_override
=
num_hidden_layers
)
self
.
vision_model
=
self
.
_init_vision_model
(
config
,
num_hidden_layers
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
vit_hidden_size
=
config
.
vision_config
.
hidden_size
llm_hidden_size
=
config
.
text_config
.
hidden_size
self
.
mlp1
=
nn
.
Sequential
(
nn
.
LayerNorm
(
vit_hidden_size
*
int
(
1
/
self
.
downsample_ratio
)
**
2
),
nn
.
Linear
(
vit_hidden_size
*
int
(
1
/
self
.
downsample_ratio
)
**
2
,
llm_hidden_size
),
nn
.
GELU
(),
nn
.
Linear
(
llm_hidden_size
,
llm_hidden_size
))
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
self
.
img_context_token_id
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
self
.
sampler
=
self
.
language_model
.
sampler
else
:
self
.
sampler
=
Sampler
()
return
self
.
language_model
.
sampler
return
Sampler
()
def _init_vision_model(self, config: PretrainedConfig,
                       num_hidden_layers: int):
    """Construct the vision encoder for this model.

    Args:
        config: Model configuration; its ``vision_config`` attribute is
            handed to ``InternVisionModel``.
        num_hidden_layers: Passed as ``num_hidden_layers_override``.

    Returns:
        The newly constructed ``InternVisionModel``.
    """
    vision_config = config.vision_config
    return InternVisionModel(vision_config,
                             num_hidden_layers_override=num_hidden_layers)
def
_init_mlp1
(
self
,
config
:
PretrainedConfig
)
->
nn
.
Sequential
:
vit_hidden_size
=
config
.
vision_config
.
hidden_size
llm_hidden_size
=
config
.
text_config
.
hidden_size
return
nn
.
Sequential
(
nn
.
LayerNorm
(
vit_hidden_size
*
int
(
1
/
self
.
downsample_ratio
)
**
2
),
nn
.
Linear
(
vit_hidden_size
*
int
(
1
/
self
.
downsample_ratio
)
**
2
,
llm_hidden_size
),
nn
.
GELU
(),
nn
.
Linear
(
llm_hidden_size
,
llm_hidden_size
),
)
def
pixel_shuffle
(
self
,
x
,
scale_factor
=
0.5
):
n
,
w
,
h
,
c
=
x
.
size
()
...
...
@@ -470,7 +554,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
self
,
image_input
:
InternVLImageInputs
,
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
...
...
@@ -487,18 +570,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
and
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
img_context_token_id
)
)
->
Union
[
SamplerOutput
,
IntermediateTensors
]:
if
intermediate_tensors
is
not
None
:
input_ids
=
None
else
:
inputs_embeds
=
None
else
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
img_context_token_id
)
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
...
...
@@ -524,19 +611,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
weights_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
self
.
vision_model
.
load_weights
(
weights_group
[
"vision_model"
])
# load mlp projector
mlp_params_dict
=
dict
(
self
.
mlp1
.
named_parameters
())
for
name
,
loaded_weight
in
weights_group
[
"mlp1"
]:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
loader
=
AutoWeightsLoader
(
self
)
loader
.
load_weights
(
weights
)
Prev
1
…
18
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment