Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
448
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1519 additions
and
166 deletions
+1519
-166
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+7
-2
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+1001
-0
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+14
-10
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+97
-38
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+12
-4
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+43
-21
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+8
-2
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+42
-11
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+13
-19
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+8
-2
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+7
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+6
-3
vllm/model_executor/models/eagle.py
vllm/model_executor/models/eagle.py
+161
-0
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+5
-2
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+45
-25
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+10
-7
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+11
-8
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+11
-4
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+13
-4
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+5
-2
No files found.
Too many changes to show.
To preserve performance only
448 of 448+
files are displayed.
Plain diff
Email patch
vllm/model_executor/models/baichuan.py
View file @
af7f4372
...
...
@@ -345,6 +345,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -369,8 +371,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/bart.py
0 → 100644
View file @
af7f4372
# Derived from BART implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
BartConfig
from
transformers.utils
import
logging
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
logger
=
logging
.
get_logger
(
__name__
)
def get_bsz_seq_len(input_ids):
    """Return the leading ``(batch, seq_len)`` pair for ``input_ids``.

    A 1-D input is treated as a single sequence and yields
    ``(1, input_ids.numel())``. For any other rank the first two entries
    of ``input_ids.shape`` are returned (as a slice of the shape).
    """
    dims = input_ids.shape
    if len(dims) != 1:
        return dims[:2]
    return 1, input_ids.numel()
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
    """
    Learned positional embedding table with a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # BART offsets the embedding ids by 2 and enlarges the table to
        # match; other models do not have this hack. The offset is stored
        # before the parent constructor runs because it sizes the table.
        self.offset = 2
        table_rows = num_embeddings + self.offset
        super().__init__(table_rows, embedding_dim)

    def forward(
        self,
        positions: torch.Tensor,
        attn_type: AttentionType,
    ) -> torch.Tensor:
        """Embed ``positions`` after shifting them by ``self.offset``.

        ``attn_type`` must not be ``AttentionType.ENCODER_DECODER``; this is
        enforced with an assertion.
        """
        assert attn_type != AttentionType.ENCODER_DECODER
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
class BartScaledWordEmbedding(VocabParallelEmbedding):
    """
    VocabParallelEmbedding whose forward output is multiplied by a
    constant embedding scale.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Multiplier applied to every looked-up embedding in forward().
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the parent embedding of ``input_ids`` times the scale."""
        embedded = super().forward(input_ids)
        return embedded * self.embed_scale
class BartParallelLMHead(ParallelLMHead):
    """
    ParallelLMHead whose forward output is divided by the embedding
    scale, i.e. effectively the inverse of BartScaledWordEmbedding.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Divisor applied to the parent's forward output.
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the parent's forward result divided by the scale."""
        projected = super().forward(input_ids)
        return projected / self.embed_scale
class BartEncoderAttention(nn.Module):
    """Encoder self-attention: fused QKV projection, attention with
    ``AttentionType.ENCODER``, then an output projection."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # NOTE: `config` defaults to None but is dereferenced here.
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        # BART uses as many KV heads as query heads.
        self.total_num_kv_heads = num_heads

        self.head_dim, leftover = divmod(embed_dim, num_heads)
        if leftover != 0:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        head_size = self.d_model // self.total_num_heads
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            head_size,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than tensor-parallel ranks: the KV heads are
            # replicated across ranks.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to q/k/v, attend, and project back."""
        fused, _ = self.qkv_proj(hidden_states)
        slice_widths = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused.split(slice_widths, dim=-1)

        context = self.attn(q,
                            k,
                            v,
                            kv_cache,
                            attn_metadata,
                            attn_type=AttentionType.ENCODER)

        projected, _ = self.out_proj(context)
        return projected
class BartDecoderSelfAttention(nn.Module):
    """Decoder self-attention: fused QKV projection, attention with
    ``AttentionType.DECODER``, then an output projection."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # NOTE: `config` defaults to None but is dereferenced here.
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        # BART uses as many KV heads as query heads.
        self.total_num_kv_heads = num_heads

        self.head_dim, leftover = divmod(embed_dim, num_heads)
        if leftover != 0:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        head_size = self.d_model // self.total_num_heads
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            head_size,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than tensor-parallel ranks: the KV heads are
            # replicated across ranks.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to q/k/v, attend, and project back."""
        fused, _ = self.qkv_proj(hidden_states)
        slice_widths = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused.split(slice_widths, dim=-1)

        context = self.attn(q,
                            k,
                            v,
                            kv_cache,
                            attn_metadata,
                            attn_type=AttentionType.DECODER)

        projected, _ = self.out_proj(context)
        return projected
class BartCrossAttention(nn.Module):
    """Decoder-to-encoder attention: queries come from the decoder states,
    keys/values from the encoder states (when given), using
    ``AttentionType.ENCODER_DECODER``."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # NOTE: `config` defaults to None but is dereferenced here.
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        # BART uses as many KV heads as query heads.
        self.total_num_kv_heads = num_heads

        self.head_dim, leftover = divmod(embed_dim, num_heads)
        if leftover != 0:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        head_size = self.d_model // self.total_num_heads
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            head_size,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than tensor-parallel ranks: the KV heads are
            # replicated across ranks.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from decoder states to encoder states.

        When ``encoder_hidden_states`` is None, ``None`` is passed to the
        attention module for both key and value.
        """
        slice_widths = [self.q_size, self.kv_size, self.kv_size]

        # (afeldman-nm 2024/07/22) TODO:
        # Need a more efficient solution for q/k/v
        # The fused projection is run on the decoder states and only the
        # query slice is kept.
        fused_dec, _ = self.qkv_proj(decoder_hidden_states)
        q, _, _ = fused_dec.split(slice_widths, dim=-1)

        k = None
        v = None
        if encoder_hidden_states is not None:
            # Same fused projection on the encoder states; only the
            # key/value slices are kept.
            fused_enc, _ = self.qkv_proj(encoder_hidden_states)
            _, k, v = fused_enc.split(slice_widths, dim=-1)

        context = self.attn(q,
                            k,
                            v,
                            kv_cache,
                            attn_metadata,
                            attn_type=AttentionType.ENCODER_DECODER)

        projected, _ = self.out_proj(context)
        return projected
class BartEncoderLayer(nn.Module):
    """One encoder block: self-attention and a two-layer feed-forward
    network, each followed by a residual add and a LayerNorm."""

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.d_model
        hidden_width = self.embed_dim
        ffn_width = config.encoder_ffn_dim
        use_bias = True

        self.self_attn = BartEncoderAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config)

        self.fc1 = ColumnParallelLinear(
            hidden_width,
            ffn_width,
            bias=use_bias,
            quant_config=quant_config,
        )
        # Constructed but not called by this class's forward(), which
        # applies `self.activation_fn` instead.
        self.act = get_act_fn("gelu", quant_config, ffn_width)
        self.fc2 = RowParallelLinear(
            ffn_width,
            hidden_width,
            bias=use_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            hidden_states
                torch.Tensor of *encoder* input embeddings.
            kv_cache:
                KV cache tensor handed to the self-attention module
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Encoder layer output torch.Tensor
        """
        # Self-attention, residual add, then LayerNorm.
        attn_out = self.self_attn(hidden_states=hidden_states,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)
        hidden_states = self.self_attn_layer_norm(hidden_states + attn_out)

        # Feed-forward, residual add, then LayerNorm.
        skip = hidden_states
        expanded, _ = self.fc1(hidden_states)
        hidden_states, _ = self.fc2(self.activation_fn(expanded))
        hidden_states = self.final_layer_norm(skip + hidden_states)

        # For float16 outputs containing inf or nan, clamp to just inside
        # the representable float16 range.
        if hidden_states.dtype == torch.float16:
            if (torch.isinf(hidden_states).any()
                    or torch.isnan(hidden_states).any()):
                limit = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states,
                                            min=-limit,
                                            max=limit)

        return hidden_states
class BartDecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention over the encoder
    output, and a two-layer feed-forward network, each followed by a
    residual add and a LayerNorm."""

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartDecoderSelfAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # afeldman-nm: personally I would call this "cross-attention",
        # however I left the name as "encoder_attn" to maintain consistency
        # with the name of the pretrained weights.
        #
        # The cross-attention module receives the same cache/quantization
        # configuration as `self_attn`: forward() hands both modules the
        # same `kv_cache`, so they must agree on how it is configured.
        self.encoder_attn = BartCrossAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        ffn_hidden_size = self.embed_dim
        # The decoder feed-forward width comes from `decoder_ffn_dim`;
        # `encoder_ffn_dim` is only correct when the two happen to match.
        ffn_intermediate_size = config.decoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_hidden_states
                torch.Tensor of *decoder* input embeddings.
            kv_cache:
                KV cache tensor
            attn_metadata:
                vLLM Attention metadata structure
            encoder_hidden_states
                torch.Tensor of *encoder* output states, or None.
        Returns:
            Decoder layer output torch.Tensor
        """
        residual = decoder_hidden_states

        # Self Attention
        hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        residual = hidden_states

        hidden_states = self.encoder_attn(
            decoder_hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
class BartEncoder(nn.Module):
    """
    Transformer encoder consisting of *config.encoder_layers*
    self attention layers. Each layer is a [`BartEncoderLayer`].
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused for the token embedding
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_source_positions = config.max_position_embeddings

        embed_dim = config.d_model
        if config.scale_embedding:
            embed_scale = math.sqrt(embed_dim)
        else:
            embed_scale = 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    embed_dim,
                                                    embed_scale=embed_scale)
        if embed_tokens is not None:
            # Reuse the weight of the embedding supplied by the caller.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )

        encoder_blocks = [
            BartEncoderLayer(config, cache_config, quant_config)
            for _ in range(config.encoder_layers)
        ]
        self.layers = nn.ModuleList(encoder_blocks)

        self.layernorm_embedding = nn.LayerNorm(embed_dim)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Encoder output torch.Tensor
        """
        # Collapse all leading dimensions, keeping the last one.
        input_ids = input_ids.view(-1, input_ids.shape[-1])
        token_embeds = self.embed_tokens(input_ids)

        position_embeds = self.embed_positions(
            positions,
            AttentionType.ENCODER,
        )
        position_embeds = position_embeds.to(token_embeds.device)

        hidden_states = self.layernorm_embedding(token_embeds +
                                                 position_embeds)

        # Each layer gets the KV cache entry at its own index.
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                kv_cache=kv_caches[layer_idx],
                attn_metadata=attn_metadata,
            )

        return hidden_states
class BartDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers.
    Each layer is a [`BartDecoderLayer`]
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused for the token embedding
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings

        if config.scale_embedding:
            embed_scale = math.sqrt(config.d_model)
        else:
            embed_scale = 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=embed_scale)
        if embed_tokens is not None:
            # Reuse the weight of the embedding supplied by the caller.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        decoder_blocks = [
            BartDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.decoder_layers)
        ]
        self.layers = nn.ModuleList(decoder_blocks)

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

    def forward(self, decoder_input_ids: torch.Tensor,
                decoder_positions: torch.Tensor,
                encoder_hidden_states: Optional[torch.Tensor],
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            decoder_positions
                Positions of *decoder* input sequence tokens.
            encoder_hidden_states:
                Tensor of encoder output embeddings
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Decoder output torch.Tensor
        """
        token_embeds = self.embed_tokens(decoder_input_ids)

        # Positional embeddings, moved to the token embeddings' device.
        position_embeds = self.embed_positions(
            decoder_positions,
            AttentionType.DECODER,
        )
        position_embeds = position_embeds.to(token_embeds.device)

        hidden_states = self.layernorm_embedding(token_embeds +
                                                 position_embeds)

        # Each layer gets the KV cache entry at its own index.
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                decoder_hidden_states=hidden_states,
                kv_cache=kv_caches[layer_idx],
                attn_metadata=attn_metadata,
                encoder_hidden_states=encoder_hidden_states,
            )

        return hidden_states
class BartModel(nn.Module):
    """BART encoder/decoder pair without a language-model head."""

    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None):
        super().__init__()

        self.config = config
        self.padding_idx = config.pad_token_id

        # Extra vocabulary entries for LoRA adapters: none without a LoRA
        # config, otherwise one extra-vocab block per adapter (minimum 1).
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.encoder = BartEncoder(config,
                                   cache_config,
                                   quant_config=quant_config)
        self.decoder = BartDecoder(config,
                                   cache_config,
                                   quant_config=quant_config)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Model output torch.Tensor
        """
        # The encoder only runs when at least one encoder token is given;
        # otherwise the decoder receives None for the encoder states.
        if encoder_input_ids.numel() > 0:
            encoder_hidden_states = self.encoder(
                input_ids=encoder_input_ids,
                positions=encoder_positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata)
        else:
            encoder_hidden_states = None

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata)
class BartForConditionalGeneration(nn.Module):
    """BART model with an LM head, logits processor and sampler, plus the
    checkpoint-loading logic that maps weight names onto this module."""

    base_model_prefix = "model"

    # Checkpoint projection name -> fused parameter name and shard id.
    stacked_params_mapping = {
        "q_proj": {
            "param_name": "qkv_proj",
            "shard_id": "q",
        },
        "k_proj": {
            "param_name": "qkv_proj",
            "shard_id": "k",
        },
        "v_proj": {
            "param_name": "qkv_proj",
            "shard_id": "v",
        },
    }

    # Substring replacements applied, in this order, to every weight name.
    params_mapping = {
        "beta": "bias",
        "gamma": "weight",
        "LayerNorm": "layernorm",
    }

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None):
        super().__init__()
        # currently all existing BART models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings
        self.config = config
        self.model = BartModel(config,
                               cache_config,
                               quant_config,
                               lora_config=lora_config)

        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.unpadded_vocab_size = unpadded_vocab_size

        if config.scale_embedding:
            embed_scale = math.sqrt(config.d_model)
        else:
            embed_scale = 1.0

        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
            intermediate_tensors:
                Accepted but not used by this method.
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head on ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def _rename_key(self, key: str):
        """Strip the leading ``"<base_model_prefix>."`` from ``key`` if
        present, then apply every ``params_mapping`` replacement."""
        prefix = f"{self.base_model_prefix}."
        if key.startswith(prefix):
            key = key[len(prefix):]

        for old, new in self.params_mapping.items():
            key = key.replace(old, new)

        return key

    def _rename_stacked_param(
        self,
        name: str,
    ) -> Tuple[str, Optional[str]]:
        """Map a q/k/v projection name onto the fused parameter.

        Returns ``(renamed, shard_id)`` for the first mapping key found in
        ``name``, or ``(name, None)`` when none matches.
        """
        for stacked_key, target in self.stacked_params_mapping.items():
            if stacked_key not in name:
                continue
            renamed = name.replace(stacked_key, target["param_name"])
            return renamed, target["shard_id"]
        return name, None

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the model.

        Regular weights are loaded into ``self.model``'s parameters by
        name. The single embedding tensor found in the checkpoint is held
        back and afterwards loaded into the encoder embedding, the decoder
        embedding and the LM head.
        """
        inner_params = dict(self.model.named_parameters())
        outer_params = dict(self.named_parameters())

        # Materialise the iterable before anything is loaded.
        pending = list(weights)

        tied_markers = (
            'shared.weight',
            'encoder.embed_tokens.weight',
            'decoder.embed_tokens.weight',
            'lm_head.weight',
        )
        tied_weight = None
        tied_shard_id = None

        for raw_name, tensor in pending:
            name = self._rename_key(raw_name)
            name, shard_id = self._rename_stacked_param(name)

            if any(marker in name for marker in tied_markers):
                # Only one embedding tensor may appear in the checkpoint.
                assert tied_weight is None, (
                    "Conflicting embedding weights.")
                tied_weight = tensor
                tied_shard_id = shard_id
                continue

            # Skip the specific downstream task weight.
            if name.startswith('cls.'):
                continue
            # use Pooler instead.
            if name.startswith('pooler.'):
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in inner_params:
                continue

            param = inner_params[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            if shard_id:
                loader(param, tensor, shard_id)
            else:
                loader(param, tensor)

        # Assign shared weight values: resolve the three destinations and
        # their loaders first, then load encoder, decoder, LM head in order.
        tied_targets = []
        for params, param_name in (
            (inner_params, 'encoder.embed_tokens.weight'),
            (inner_params, 'decoder.embed_tokens.weight'),
            (outer_params, 'lm_head.weight'),
        ):
            target = params[param_name]
            target_loader = getattr(target, "weight_loader",
                                    default_weight_loader)
            tied_targets.append((target, target_loader))

        assert tied_weight is not None

        for target, target_loader in tied_targets:
            if tied_shard_id:
                target_loader(target, tied_weight, tied_shard_id)
            else:
                target_loader(target, tied_weight)
vllm/model_executor/models/blip.py
View file @
af7f4372
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from
array
import
array
from
typing
import
Optional
,
Union
import
torch
...
...
@@ -14,9 +15,9 @@ from vllm.model_executor.layers.activation import get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.multimodal.
image
import
(
cached_get_tokenizer
,
repeat_and_pad_
image
_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.multimodal.
utils
import
(
cached_get_tokenizer
,
repeat_and_pad_
placeholder
_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
@@ -31,13 +32,13 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
def
get_blip_image_feature_size
(
hf_config
:
Union
[
BlipVisionConfig
,
Blip2VisionConfig
]
,
)
->
int
:
hf_config
:
Union
[
BlipVisionConfig
,
Blip2VisionConfig
])
->
int
:
return
get_blip_num_patches
(
image_size
=
hf_config
.
image_size
,
patch_size
=
hf_config
.
patch_size
)
def
get_max_blip_image_tokens
(
hf_config
:
Union
[
BlipVisionConfig
,
Blip2VisionConfig
]
,
)
->
int
:
hf_config
:
Union
[
BlipVisionConfig
,
Blip2VisionConfig
])
->
int
:
return
get_blip_image_feature_size
(
hf_config
)
...
...
@@ -53,13 +54,16 @@ def dummy_seq_data_for_blip(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_blip
(
hf_config
:
Union
[
BlipVisionConfig
,
Blip2VisionConfig
],
num_images
:
int
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
...
...
@@ -71,7 +75,7 @@ def dummy_image_for_blip(
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
input_processor_for_blip
(
...
...
@@ -93,11 +97,11 @@ def input_processor_for_blip(
else
:
image_feature_size
=
image_feature_size_override
new_prompt
,
new_token_ids
=
repeat_and_pad_
image
_tokens
(
new_prompt
,
new_token_ids
=
repeat_and_pad_
placeholder
_tokens
(
tokenizer
,
llm_inputs
.
get
(
"prompt"
),
llm_inputs
[
"prompt_token_ids"
],
image
_token_id
=
image_token_id
,
placeholder
_token_id
=
image_token_id
,
repeat_count
=
image_feature_size
,
)
...
...
vllm/model_executor/models/blip2.py
View file @
af7f4372
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
from
array
import
array
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
...
...
@@ -16,18 +18,42 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.opt
import
OPTModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
get_max_blip_image_tokens
)
from
.interfaces
import
Supports
Vision
from
.utils
import
merge_
vision
_embeddings
from
.interfaces
import
Supports
MultiModal
from
.utils
import
merge_
multimodal
_embeddings
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN
=
"<image>"
BLIP2_IMAGE_TOKEN_ID
=
50265
class
Blip2ImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class
Blip2ImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Blip2ImageInputs
=
Union
[
Blip2ImagePixelInputs
,
Blip2ImageEmbeddingInputs
]
class
Blip2QFormerMultiHeadAttention
(
nn
.
Module
):
...
...
@@ -375,20 +401,6 @@ class Blip2QFormerModel(nn.Module):
return
sequence_output
class
Blip2ImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channels, height, width)"""
Blip2ImageInputs
=
Blip2ImagePixelInputs
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN
=
"<image>"
BLIP2_IMAGE_TOKEN_ID
=
50265
def
get_blip2_image_feature_size
(
hf_config
:
Blip2Config
)
->
int
:
return
hf_config
.
num_query_tokens
...
...
@@ -404,17 +416,41 @@ def get_max_blip2_image_tokens(ctx: InputContext):
raise
NotImplementedError
(
msg
)
def
dummy_data_for_blip2
(
ctx
:
InputContext
,
seq_len
:
int
):
def dummy_seq_data_for_blip2(
    hf_config: Blip2Config,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy ``SequenceData`` made of image placeholders plus padding.

    The sequence starts with ``image_token_id`` repeated once per image
    feature for each of ``num_images`` images, and is then extended with
    token id ``0`` for the remaining ``seq_len`` positions.

    Args:
        hf_config: BLIP-2 config; only consulted (through
            ``get_blip2_image_feature_size``) when no override is given.
        seq_len: Target length used to size the zero padding.
        num_images: Number of images to reserve placeholders for.
        image_token_id: Token id used as the image placeholder.
        image_feature_size_override: Placeholders per image; when ``None``
            the value is derived from ``hf_config``.

    Returns:
        A ``SequenceData`` wrapping the generated token-id array.
    """
    tokens_per_image = image_feature_size_override
    if tokens_per_image is None:
        tokens_per_image = get_blip2_image_feature_size(hf_config)

    # Repeat in the same order as before (per-image run first, then per
    # image) so that array repetition semantics are unchanged.
    placeholder_unit = array(VLLM_TOKEN_ID_ARRAY_TYPE, [image_token_id])
    dummy_ids = placeholder_unit * tokens_per_image * num_images

    padding_unit = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])
    dummy_ids += padding_unit * (seq_len - tokens_per_image * num_images)

    return SequenceData(dummy_ids)
def
dummy_data_for_blip2
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
Blip2Config
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
image_feature_size
=
get_blip2_image_feature_size
(
hf_config
)
token_ids
=
[
BLIP2_IMAGE_TOKEN_ID
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
seq_data
=
SequenceData
(
token_ids
)
seq_data
=
dummy_seq_data_for_blip2
(
hf_config
,
seq_len
,
num_images
,
image_token_id
=
BLIP2_IMAGE_TOKEN_ID
,
)
if
isinstance
(
vision_config
,
Blip2VisionConfig
):
mm_data
=
dummy_image_for_blip
(
vision_config
)
mm_data
=
dummy_image_for_blip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
...
...
@@ -448,7 +484,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_blip2_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_blip2
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_blip2
)
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
Supports
Vision
):
class
Blip2ForConditionalGeneration
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
Blip2Config
,
...
...
@@ -458,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
super
().
__init__
()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert
config
.
tie_word_embeddings
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
...
...
@@ -506,18 +545,31 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Blip2ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
:
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
Blip2ImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
return
Blip2ImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
Blip2ImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_image_pixels_to_features
(
self
,
vision_model
:
BlipVisionModel
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -538,6 +590,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def
_process_image_input
(
self
,
image_input
:
Blip2ImageInputs
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
assert
self
.
vision_model
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
...
...
@@ -595,9 +651,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_
vision
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
BLIP2_IMAGE_TOKEN_ID
)
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
BLIP2_IMAGE_TOKEN_ID
)
input_ids
=
None
else
:
...
...
@@ -611,8 +667,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
get_lm_head
(),
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/bloom.py
View file @
af7f4372
...
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
self
.
transformer
.
word_embeddings
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
word_embeddings
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -292,8 +297,11 @@ class BloomForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/chameleon.py
View file @
af7f4372
from
array
import
array
from
functools
import
cached_property
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
)
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
)
import
torch
import
torch.nn.functional
as
F
...
...
@@ -19,21 +20,23 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
row_parallel_weight_loader
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
vllm.utils
import
print_warning_once
from
.interfaces
import
Supports
Vision
from
.interfaces
import
Supports
MultiModal
logger
=
init_logger
(
__name__
)
...
...
@@ -59,6 +62,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
def
dummy_seq_data_for_chameleon
(
seq_len
:
int
,
num_images
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
...
...
@@ -68,12 +72,16 @@ def dummy_seq_data_for_chameleon(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_chameleon
(
num_images
:
int
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
...
...
@@ -85,17 +93,20 @@ def dummy_image_for_chameleon(
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_chameleon
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_chameleon
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_chameleon
(
seq_len
,
num_images
,
image_token_id
=
CHAMELEON_IMAGE_TOKEN_ID
,
)
mm_data
=
dummy_image_for_chameleon
()
mm_data
=
dummy_image_for_chameleon
(
num_images
)
return
seq_data
,
mm_data
...
...
@@ -113,11 +124,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
new_prompt
,
new_token_ids
=
repeat_and_pad_
image
_tokens
(
new_prompt
,
new_token_ids
=
repeat_and_pad_
placeholder
_tokens
(
tokenizer
,
llm_inputs
.
get
(
"prompt"
),
llm_inputs
[
"prompt_token_ids"
],
image
_token_id
=
CHAMELEON_IMAGE_TOKEN_ID
,
placeholder
_token_id
=
CHAMELEON_IMAGE_TOKEN_ID
,
repeat_count
=
CHAMELEON_IMAGE_SEQ_LENGTH
,
pad_token_left
=
CHAMELEON_IMAGE_START_TOKEN_ID
,
pad_token_right
=
CHAMELEON_IMAGE_END_TOKEN_ID
,
...
...
@@ -141,6 +152,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
super
().
__init__
(
hidden_size
,
*
args
,
**
kwargs
)
self
.
normalized_shape
=
(
hidden_size
[
-
1
],
)
set_weight_attrs
(
self
.
weight
,
{
"weight_loader"
:
row_parallel_weight_loader
})
set_weight_attrs
(
self
.
bias
,
{
"weight_loader"
:
row_parallel_weight_loader
})
def
forward
(
self
,
hidden_states
):
hidden_states
=
F
.
layer_norm
(
hidden_states
,
self
.
normalized_shape
,
...
...
@@ -697,6 +713,8 @@ class ChameleonVQVAEEncoder(nn.Module):
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
):
pixel_values
=
pixel_values
.
to
(
self
.
conv_in
.
weight
.
dtype
)
# downsampling
hidden_states
=
[
self
.
conv_in
(
pixel_values
)]
for
i_level
in
range
(
self
.
num_resolutions
):
...
...
@@ -877,7 +895,7 @@ class ChameleonModel(nn.Module):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_chameleon_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_chameleon
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_chameleon
)
class
ChameleonForConditionalGeneration
(
nn
.
Module
,
Supports
Vision
):
class
ChameleonForConditionalGeneration
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
...
...
@@ -959,15 +977,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
# Disallow image tokens which do not include special
# begin-image and end-image tokens
image_tokens
=
self
.
model
.
vocabulary_mapping
.
image_tokens
logits
[:,
image_tokens
]
=
torch
.
finfo
(
logits
.
dtype
).
min
if
logits
is
not
None
:
image_tokens
=
self
.
model
.
vocabulary_mapping
.
image_tokens
logits
[:,
image_tokens
]
=
torch
.
finfo
(
logits
.
dtype
).
min
return
logits
...
...
vllm/model_executor/models/chatglm.py
View file @
af7f4372
...
...
@@ -368,6 +368,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
8192
)
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
transformer
.
output_layer
.
weight
=
(
self
.
transformer
.
embedding
.
weight
)
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -393,8 +396,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/clip.py
View file @
af7f4372
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from
typing
import
Optional
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
...
...
@@ -14,9 +15,10 @@ from vllm.model_executor.layers.activation import get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
@@ -32,7 +34,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
def
get_clip_image_feature_size
(
hf_config
:
CLIPVisionConfig
)
->
int
:
return
get_clip_num_patches
(
image_size
=
hf_config
.
image_size
,
patch_size
=
hf_config
.
patch_size
)
patch_size
=
hf_config
.
patch_size
)
+
1
def
get_max_clip_image_tokens
(
hf_config
:
CLIPVisionConfig
)
->
int
:
...
...
@@ -42,6 +44,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
def
dummy_seq_data_for_clip
(
hf_config
:
CLIPVisionConfig
,
seq_len
:
int
,
num_images
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
...
...
@@ -51,13 +54,16 @@ def dummy_seq_data_for_clip(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_clip
(
hf_config
:
CLIPVisionConfig
,
num_images
:
int
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
...
...
@@ -69,7 +75,7 @@ def dummy_image_for_clip(
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
input_processor_for_clip
(
...
...
@@ -87,15 +93,21 @@ def input_processor_for_clip(
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
if
image_feature_size_override
is
None
:
image_feature_size
=
get_clip_image_feature_size
(
hf_config
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
image_feature_size
=
get_clip_image_feature_size
(
hf_config
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
else
:
image_feature_size
=
image_feature_size_override
new_prompt
,
new_token_ids
=
repeat_and_pad_
image
_tokens
(
new_prompt
,
new_token_ids
=
repeat_and_pad_
placeholder
_tokens
(
tokenizer
,
llm_inputs
.
get
(
"prompt"
),
llm_inputs
[
"prompt_token_ids"
],
image
_token_id
=
image_token_id
,
placeholder
_token_id
=
image_token_id
,
repeat_count
=
image_feature_size
,
)
...
...
@@ -291,3 +303,22 @@ class CLIPVisionModel(nn.Module):
@
property
def
device
(
self
):
return
next
(
self
.
parameters
()).
device
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Copy checkpoint tensors into this module's parameters.

    Weights whose name contains ``vision_model.post_layernorm`` are
    skipped, as are encoder-layer weights whose layer index (the fourth
    dot-separated component of the name) is not below the number of
    layers held by ``self.vision_model.encoder.layers``.

    Args:
        weights: Iterable of ``(parameter name, tensor)`` pairs.

    Raises:
        KeyError: If a non-skipped name is not one of this module's
            named parameters.
    """
    own_params = dict(self.named_parameters())
    num_layers = len(self.vision_model.encoder.layers)

    for weight_name, tensor in weights:
        # post_layernorm is not needed in CLIPVisionModel
        if "vision_model.post_layernorm" in weight_name:
            continue

        # omit layers when num_hidden_layers_override is set
        if ("vision_model.encoder.layers." in weight_name
                and int(weight_name.split(".")[3]) >= num_layers):
            continue

        target = own_params[weight_name]
        loader = getattr(target, "weight_loader", default_weight_loader)
        loader(target, tensor)
vllm/model_executor/models/commandr.py
View file @
af7f4372
...
...
@@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
transformers
import
CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
row_parallel_weight_loader
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
param_shape
))
self
.
variance_epsilon
=
eps
set_weight_attrs
(
self
.
weight
,
{
"weight_loader"
:
self
.
weight_loader
})
set_weight_attrs
(
self
.
weight
,
{
"weight_loader"
:
row_parallel_weight_loader
})
def
forward
(
self
,
hidden_states
,
residuals
=
None
):
hidden_states
=
layer_norm_func
(
hidden_states
,
self
.
weight
,
self
.
variance_epsilon
)
return
hidden_states
,
residuals
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
shard_dim
=
0
if
param
.
dim
()
!=
1
else
None
param_data
=
param
.
data
if
shard_dim
is
not
None
:
shard_size
=
param_data
.
shape
[
shard_dim
]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
start_idx
,
shard_size
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class
CohereMLP
(
nn
.
Module
):
...
...
@@ -333,6 +321,9 @@ class CohereForCausalLM(nn.Module):
)
->
None
:
super
().
__init__
()
self
.
config
=
config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert
config
.
tie_word_embeddings
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
@@ -359,8 +350,11 @@ class CohereForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
is_not_lora
=
hasattr
(
self
.
model
.
embed_tokens
,
'weight'
)
if
is_not_lora
:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
...
...
vllm/model_executor/models/dbrx.py
View file @
af7f4372
...
...
@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
):
super
().
__init__
()
self
.
config
=
config
if
config
.
tie_word_embeddings
:
raise
ValueError
(
"tie_word_embeddings is not supported for Dbrx models."
)
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
cache_config
,
quant_config
)
...
...
@@ -388,8 +391,11 @@ class DbrxForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/deepseek.py
View file @
af7f4372
...
...
@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -395,8 +397,11 @@ class DeepseekForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
af7f4372
...
...
@@ -505,8 +505,11 @@ class DeepseekV2ForCausalLM(nn.Module):
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
@@ -590,7 +593,7 @@ class DeepseekV2ForCausalLM(nn.Module):
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_
name
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
...
...
vllm/model_executor/models/eagle.py
0 → 100644
View file @
af7f4372
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.transformers_utils.configs.eagle
import
EAGLEConfig
class
EAGLE
(
nn
.
Module
):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
but we do as HF implementation also does.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
    """Build the wrapped draft model, the fusion layer and the LM head.

    Args:
        config: EAGLE config; ``config.model`` is the config of the
            wrapped model and supplies ``architectures`` and
            ``hidden_size``.
        *args: Forwarded unchanged to the wrapped model's constructor.
        **kwargs: Forwarded unchanged to the wrapped model's constructor.
    """
    super().__init__()
    self.config = config

    # Resolve and instantiate the wrapped model class from the
    # architectures listed on the inner config.
    inner_config = self.config.model
    arch_names = getattr(inner_config, "architectures", [])
    inner_cls, _ = ModelRegistry.resolve_model_cls(arch_names)
    self.model = inner_cls(inner_config, *args, **kwargs)

    # Bias-free projection from 2 * hidden_size down to hidden_size.
    self.fc = nn.Linear(
        config.model.hidden_size * 2,
        config.model.hidden_size,
        bias=False,
    )

    self.orig_vocab_size = config.vocab_size
    self.truncated_vocab_size = config.truncated_vocab_size
    self.unpadded_vocab_size = self.truncated_vocab_size

    self.lm_head = ParallelLMHead(
        self.unpadded_vocab_size,
        config.hidden_size,
        org_num_embeddings=self.truncated_vocab_size,
        padding_size=DEFAULT_VOCAB_PADDING_SIZE,
    )

    self.logits_processor = LogitsProcessor(
        self.unpadded_vocab_size,
        self.truncated_vocab_size,
        getattr(config, "logit_scale", 1.0),
    )

    # Token map is an index-to-token mapping used to reduce the vocab
    # size for the draft model. Using a smaller draft vocab that contains
    # only the most frequent tokens reduces the speculation overhead.
    # This doesn't affect the acceptance rate much and thus gives more
    # speed-up. By default, this is disabled and is only used if the
    # EAGLE checkpoint file has a token_map tensor.
    self.token_map = None
@property
def sampler(self):
    """Expose the sampler owned by the wrapped model."""
    wrapped_model = self.model
    return wrapped_model.sampler
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    previous_hidden_states: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
    """Run the wrapped decoder on fused token/hidden-state embeddings.

    The token embeddings of ``input_ids`` are concatenated with
    ``previous_hidden_states`` along the last dimension, projected through
    ``self.fc``, zeroed wherever ``positions == 0`` and then passed to the
    wrapped decoder as ``inputs_embeds`` (with ``input_ids=None``).

    Returns:
        Whatever the wrapped decoder returns for those embeddings.
    """
    decoder = self.model.model

    token_embeddings = decoder.embed_tokens(input_ids)
    fused = torch.cat([token_embeddings, previous_hidden_states], dim=-1)
    inputs_embeds = self.fc(fused)

    # masking inputs at position=0
    at_first_position = positions == 0
    inputs_embeds[at_first_position] = 0

    return decoder(
        input_ids=None,
        inputs_embeds=inputs_embeds,
        positions=positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
        intermediate_tensors=intermediate_tensors,
    )
def compute_logits(
        self, hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
    """Compute draft logits and expand them to the original vocabulary.

    Args:
        hidden_states: Hidden states passed to the logits processor.
        sampling_metadata: Sampling metadata passed to the logits processor.

    Returns:
        The logits from ``self.logits_processor``. When ``self.token_map``
        is set, they are scattered into a tensor whose last dimension is
        ``self.orig_vocab_size``; every vocabulary entry that is not in the
        token map is ``-inf``. ``None`` is passed through unchanged.
    """
    logits = self.logits_processor(self.lm_head, hidden_states,
                                   sampling_metadata)

    # NOTE(review): the logits processor appears able to return None
    # (presumably on ranks that do not gather logits - confirm); there is
    # nothing to remap in that case, so do not touch `.shape`.
    if logits is None or self.token_map is None:
        return logits

    # Start from -inf everywhere so tokens outside the truncated draft
    # vocabulary can never be sampled, then drop the draft logits into the
    # slots named by the token map.
    expanded = torch.full(
        size=(*logits.shape[:-1], self.orig_vocab_size),
        fill_value=float("-inf"),
        device=logits.device,
        dtype=logits.dtype,
    )
    expanded[..., self.token_map] = logits
    return expanded
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Sample next tokens by delegating to ``self.sampler``."""
    return self.sampler(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load an EAGLE checkpoint into ``fc``, ``lm_head`` and the wrapped model.

    Args:
        weights: ``(name, tensor)`` pairs read from the checkpoint.

    Raises:
        KeyError: If, after name normalisation, the checkpoint provides no
            ``lm_head.weight`` tensor.
    """
    # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
    # due to missing lm_head weights and its config being that of a
    # Llama model. Here's a compatible version with the same weights:
    # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
    # Also, here's an example script for converting trained EAGLE
    # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
    model_weights = {}
    for name, loaded_weight in weights:
        if name == "token_map":
            # Only keep the map when the config actually asks for a
            # truncated draft vocabulary.
            if self.config.truncated_vocab_size < self.config.vocab_size:
                self.token_map = nn.Parameter(loaded_weight,
                                              requires_grad=False)
        elif name.startswith("fc."):
            weight_loader = getattr(self.fc.weight, "weight_loader",
                                    default_weight_loader)
            weight_loader(self.fc.weight, loaded_weight)
        elif name.startswith(("model.lm_head.", "model.model.")):
            # Drop the outer "model." so the key matches the wrapped
            # model's own naming.
            model_weights[name.split("model.", 1)[-1]] = loaded_weight
        elif name.startswith(("lm_head.", "model.")):
            model_weights[name] = loaded_weight
        else:
            # Bare decoder weights: qualify them with "model.".
            model_weights[f"model.{name}"] = loaded_weight

    # Fail with an actionable message instead of a bare
    # KeyError('lm_head.weight'); the exception type is unchanged.
    if "lm_head.weight" not in model_weights:
        raise KeyError(
            "lm_head.weight not found in the EAGLE checkpoint; convert the "
            "checkpoint to the vLLM-compatible layout (with lm_head "
            "weights) before loading")
    lm_head_weight = model_weights.pop("lm_head.weight")

    # A full-vocabulary head is cut down to the rows selected by the
    # token map; an already-truncated head is loaded as is.
    if (self.token_map is not None
            and lm_head_weight.shape[0] > self.token_map.shape[0]):
        lm_head_weight = lm_head_weight[self.token_map]

    weight_loader = getattr(self.lm_head.weight, "weight_loader",
                            default_weight_loader)
    weight_loader(self.lm_head.weight, lm_head_weight)

    self.model.load_weights(model_weights.items())
vllm/model_executor/models/falcon.py
View file @
af7f4372
...
...
@@ -420,8 +420,11 @@ class FalconForCausalLM(nn.Module):
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/fuyu.py
View file @
af7f4372
...
...
@@ -16,7 +16,8 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import
math
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
from
array
import
array
from
typing
import
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
import
torch
import
torch.nn
as
nn
...
...
@@ -29,19 +30,19 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.persimmon
import
PersimmonForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
cached_get_tokenizer
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.interfaces
import
Supports
Vision
from
.utils
import
merge_
vision
_embeddings
from
.interfaces
import
Supports
MultiModal
from
.utils
import
merge_
multimodal
_embeddings
logger
=
init_logger
(
__name__
)
...
...
@@ -94,27 +95,36 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
return
(
ncol
+
1
)
*
nrow
def
dummy_seq_data_for_fuyu
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_seq_data_for_fuyu
(
ctx
:
InputContext
,
seq_len
:
int
,
num_images
:
int
):
ncol
,
nrow
=
get_max_fuyu_image_feature_size
()
image_feature_size
=
get_max_fuyu_image_tokens
(
ctx
)
token_ids
=
([
_IMAGE_TOKEN_ID
]
*
ncol
+
[
_NEWLINE_TOKEN_ID
])
*
nrow
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
image_token_ids
=
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
_IMAGE_TOKEN_ID
])
*
ncol
+
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
_NEWLINE_TOKEN_ID
]))
*
nrow
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
image_token_ids
)
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_fuyu
(
num_images
:
int
,
*
,
image_width
:
int
,
image_height
:
int
,
):
image
=
Image
.
new
(
"RGB"
,
(
image_width
,
image_height
),
color
=
0
)
return
{
"image"
:
image
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_data_for_fuyu
(
ctx
:
InputContext
,
seq_len
:
int
):
seq_data
=
dummy_seq_data_for_fuyu
(
ctx
,
seq_len
)
mm_data
=
dummy_image_for_fuyu
(
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
MAX_IMAGE_FEATURE_SIZE_HEIGHT
)
def
dummy_data_for_fuyu
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_fuyu
(
ctx
,
seq_len
,
num_images
)
mm_data
=
dummy_image_for_fuyu
(
num_images
,
image_width
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
,
image_height
=
MAX_IMAGE_FEATURE_SIZE_HEIGHT
)
return
seq_data
,
mm_data
...
...
@@ -209,7 +219,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_fuyu_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_fuyu
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_fuyu
)
class
FuyuForCausalLM
(
nn
.
Module
,
Supports
Vision
):
class
FuyuForCausalLM
(
nn
.
Module
,
Supports
MultiModal
):
def
__init__
(
self
,
config
:
FuyuConfig
,
...
...
@@ -234,7 +244,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
cache_config
=
cache_config
,
quant_config
=
quant_config
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePixelInputs
]:
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
if
isinstance
(
image_patches
,
torch
.
Tensor
):
...
...
@@ -249,6 +260,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
data
=
image_patches
)
return
None
def
_process_image_input
(
self
,
image_input
:
FuyuImagePixelInputs
)
->
torch
.
Tensor
:
assert
self
.
vision_embed_tokens
is
not
None
vision_embeddings
,
_
=
self
.
vision_embed_tokens
(
image_input
[
"data"
])
return
vision_embeddings
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -261,12 +279,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
vision_embeddings
,
_
=
self
.
vision_embed_tokens
(
image_input
[
"data"
])
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
embed_tokens
(
input_ids
)
inputs_embeds
=
merge_
vision
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
image_token_id
)
inputs_embeds
=
merge_
multimodal
_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
image_token_id
)
else
:
inputs_embeds
=
None
...
...
@@ -280,8 +297,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
language_model
.
logits_processor
(
self
.
language_model
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/gemma.py
View file @
af7f4372
...
...
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
GemmaRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -148,14 +148,12 @@ class GemmaAttention(nn.Module):
quant_config
=
quant_config
,
)
# TODO(woosuk): Use the `get_rope` interface.
self
.
rotary_emb
=
GemmaRotaryEmbedding
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
_embeddings
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
self
.
rope_theta
,
is_neox_style
=
True
,
dtype
=
torch
.
get_default_dtype
(),
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
...
...
@@ -333,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
super
().
__init__
()
self
.
config
=
config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
...
...
@@ -352,8 +352,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/gemma2.py
View file @
af7f4372
...
...
@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
GemmaRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -130,14 +130,12 @@ class Gemma2Attention(nn.Module):
bias
=
config
.
attention_bias
,
quant_config
=
quant_config
,
)
# TODO(woosuk): Use the `get_rope` interface.
self
.
rotary_emb
=
GemmaRotaryEmbedding
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
max_position_embeddings
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
self
.
rope_theta
,
is_neox_style
=
True
,
dtype
=
torch
.
get_default_dtype
(),
)
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
...
...
@@ -325,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
del
lora_config
# Unused.
super
().
__init__
()
self
.
config
=
config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
quant_config
=
quant_config
self
.
model
=
Gemma2Model
(
config
,
cache_config
,
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
...
...
@@ -343,8 +343,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/gpt2.py
View file @
af7f4372
...
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
cache_config
,
quant_config
,
prefix
=
"transformer"
)
self
.
lm_head
=
self
.
transformer
.
wte
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -265,8 +269,11 @@ class GPT2LMHeadModel(nn.Module):
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
af7f4372
...
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
config
,
cache_config
,
quant_config
,
lora_config
)
self
.
lm_head
=
self
.
transformer
.
wte
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
transformer
.
vocab_size
,
self
.
transformer
.
embed_dim
,
org_num_embeddings
=
self
.
config
.
vocab_size
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
@@ -279,8 +285,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
...
...
vllm/model_executor/models/gpt_j.py
View file @
af7f4372
...
...
@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
return
logits
...
...
Prev
1
…
18
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment