Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdc1fa12
Unverified
Commit
cdc1fa12
authored
Feb 25, 2025
by
Harry Mellor
Committed by
GitHub
Feb 24, 2025
Browse files
Remove unused kwargs from model definitions (#13555)
parent
f61528d4
Changes
104
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
77 additions
and
287 deletions
+77
-287
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+7
-28
vllm/model_executor/models/internlm2_ve.py
vllm/model_executor/models/internlm2_ve.py
+2
-12
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+0
-5
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+8
-24
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+5
-24
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+7
-21
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+0
-5
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+0
-5
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+0
-5
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+0
-5
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+4
-12
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+7
-11
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+5
-19
vllm/model_executor/models/minicpm3.py
vllm/model_executor/models/minicpm3.py
+2
-4
vllm/model_executor/models/minicpmo.py
vllm/model_executor/models/minicpmo.py
+0
-5
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+0
-5
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+6
-20
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+6
-20
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+15
-35
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+3
-22
No files found.
vllm/model_executor/models/internlm2.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module):
...
@@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
wqkv
(
hidden_states
)
qkv
,
_
=
self
.
wqkv
(
hidden_states
)
q
,
k
,
v
=
self
.
split_qkv
(
qkv
)
q
,
k
,
v
=
self
.
split_qkv
(
qkv
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
wo
(
attn_output
)
output
,
_
=
self
.
wo
(
attn_output
)
return
output
return
output
...
@@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module):
...
@@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module):
...
@@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states
=
self
.
attention
(
hidden_states
=
self
.
attention
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -290,8 +284,6 @@ class InternLM2Model(nn.Module):
...
@@ -290,8 +284,6 @@ class InternLM2Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -305,15 +297,8 @@ class InternLM2Model(nn.Module):
...
@@ -305,15 +297,8 @@ class InternLM2Model(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
...
@@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
@@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
...
@@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
logits
,
_
=
self
.
v_head
(
hidden_states
)
logits
,
_
=
self
.
v_head
(
hidden_states
)
return
logits
return
logits
...
...
vllm/model_executor/models/internlm2_ve.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module):
...
@@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
visual_token_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
visual_token_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module):
...
@@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module):
hidden_states
=
self
.
attention
(
hidden_states
=
self
.
attention
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model):
...
@@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
visual_token_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
visual_token_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model):
...
@@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
visual_token_mask
=
visual_token_mask
,
visual_token_mask
=
visual_token_mask
,
)
)
...
...
vllm/model_executor/models/internvl.py
View file @
cdc1fa12
...
@@ -17,7 +17,6 @@ import torchvision.transforms as T
...
@@ -17,7 +17,6 @@ import torchvision.transforms as T
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
BatchFeature
,
PretrainedConfig
,
TensorType
from
transformers
import
BatchFeature
,
PretrainedConfig
,
TensorType
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
...
@@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
forward_kwargs
=
{
forward_kwargs
=
{
"input_ids"
:
input_ids
,
"input_ids"
:
input_ids
,
"positions"
:
positions
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
"intermediate_tensors"
:
intermediate_tensors
,
"intermediate_tensors"
:
intermediate_tensors
,
"inputs_embeds"
:
inputs_embeds
,
"inputs_embeds"
:
inputs_embeds
,
}
}
...
...
vllm/model_executor/models/jais.py
View file @
cdc1fa12
...
@@ -21,12 +21,12 @@
...
@@ -21,12 +21,12 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -123,12 +123,10 @@ class JAISAttention(nn.Module):
...
@@ -123,12 +123,10 @@ class JAISAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
return
attn_output
return
attn_output
...
@@ -200,16 +198,10 @@ class JAISBlock(nn.Module):
...
@@ -200,16 +198,10 @@ class JAISBlock(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_output
=
self
.
attn
(
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
)
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# residual connection
# residual connection
hidden_states
=
attn_output
+
residual
hidden_states
=
attn_output
+
residual
...
@@ -266,8 +258,6 @@ class JAISModel(nn.Module):
...
@@ -266,8 +258,6 @@ class JAISModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
...
@@ -285,11 +275,8 @@ class JAISModel(nn.Module):
...
@@ -285,11 +275,8 @@ class JAISModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
h
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
@@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
...
@@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/jamba.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Jamba model."""
"""Inference-only Jamba model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
JambaConfig
from
transformers
import
JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
@@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
**
kwargs
,
...
@@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
)
mamba_cache_params
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
...
@@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
...
@@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module):
...
@@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
**
kwargs
,
**
kwargs
,
):
):
...
@@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module):
...
@@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attention
(
hidden_states
=
self
.
self_attention
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
...
@@ -320,8 +311,6 @@ class JambaModel(nn.Module):
...
@@ -320,8 +311,6 @@ class JambaModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -339,12 +328,9 @@ class JambaModel(nn.Module):
...
@@ -339,12 +328,9 @@ class JambaModel(nn.Module):
kv_cache_index
=
0
kv_cache_index
=
0
mamba_cache_index
=
0
mamba_cache_index
=
0
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
kv_cache
=
None
layer_mamba_cache_params
=
None
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
kv_cache
=
kv_caches
[
kv_cache_index
]
kv_cache_index
+=
1
kv_cache_index
+=
1
if
isinstance
(
layer
,
JambaMambaDecoderLayer
):
if
isinstance
(
layer
,
JambaMambaDecoderLayer
):
current_state_layer
=
mamba_cache_index
current_state_layer
=
mamba_cache_index
...
@@ -355,8 +341,6 @@ class JambaModel(nn.Module):
...
@@ -355,8 +341,6 @@ class JambaModel(nn.Module):
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
)
mamba_cache_params
=
layer_mamba_cache_params
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
...
@@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
attn_metadata
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/llama.py
View file @
cdc1fa12
...
@@ -22,13 +22,13 @@
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -197,13 +197,11 @@ class LlamaAttention(nn.Module):
...
@@ -197,13 +197,11 @@ class LlamaAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
)
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
...
@@ -347,8 +341,6 @@ class LlamaModel(nn.Module):
...
@@ -347,8 +341,6 @@ class LlamaModel(nn.Module):
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -363,11 +355,8 @@ class LlamaModel(nn.Module):
...
@@ -363,11 +355,8 @@ class LlamaModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
@@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
model_output
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
model_output
return
model_output
...
...
vllm/model_executor/models/llava.py
View file @
cdc1fa12
...
@@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION
...
@@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from
transformers.models.llava
import
LlavaProcessor
from
transformers.models.llava
import
LlavaProcessor
from
transformers.models.pixtral
import
PixtralProcessor
from
transformers.models.pixtral
import
PixtralProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
InputProcessingContext
from
vllm.inputs
import
InputProcessingContext
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
...
@@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/llava_next.py
View file @
cdc1fa12
...
@@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
...
@@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape
,
unpad_image
)
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/llava_next_video.py
View file @
cdc1fa12
...
@@ -10,7 +10,6 @@ import torch.nn as nn
...
@@ -10,7 +10,6 @@ import torch.nn as nn
from
transformers
import
(
BatchFeature
,
LlavaNextVideoConfig
,
from
transformers
import
(
BatchFeature
,
LlavaNextVideoConfig
,
LlavaNextVideoProcessor
)
LlavaNextVideoProcessor
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
...
@@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/llava_onevision.py
View file @
cdc1fa12
...
@@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
...
@@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape
,
unpad_image
)
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
...
@@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/mamba.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA model."""
"""PyTorch MAMBA model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
MambaConfig
from
transformers
import
MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
...
@@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module):
...
@@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
**
kwargs
,
...
@@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module):
...
@@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module):
else
:
else
:
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
hidden_states
=
self
.
mixer
(
hidden_states
,
mamba_cache_params
)
mamba_cache_params
)
return
hidden_states
,
residual
return
hidden_states
,
residual
...
@@ -125,7 +122,6 @@ class MambaModel(nn.Module):
...
@@ -125,7 +122,6 @@ class MambaModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -146,7 +142,6 @@ class MambaModel(nn.Module):
...
@@ -146,7 +142,6 @@ class MambaModel(nn.Module):
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
))
i
-
self
.
start_layer
))
...
@@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
...
@@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
...
@@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
...
@@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
mamba_cache_params
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/mamba2.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model."""
"""PyTorch MAMBA2 model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
...
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
...
@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
],
sequence_idx
:
Optional
[
torch
.
Tensor
],
...
@@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
else
:
else
:
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
hidden_states
=
self
.
mixer
(
hidden_states
,
mamba_cache_params
,
mamba_cache_params
,
sequence_idx
)
sequence_idx
)
return
hidden_states
,
residual
return
hidden_states
,
residual
...
@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
...
@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
...
@@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
# proper continuous batching computation including
# proper continuous batching computation including
# chunked prefill
# chunked prefill
seq_idx
=
None
seq_idx
=
None
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefills
>
0
:
if
attn_metadata
.
num_prefills
>
0
:
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
for
i
,
(
srt
,
end
)
in
enumerate
(
for
i
,
(
srt
,
end
)
in
enumerate
(
...
@@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
...
@@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
),
i
-
self
.
start_layer
),
...
@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
...
@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
mamba_cache_params
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/minicpm.py
View file @
cdc1fa12
...
@@ -23,13 +23,13 @@
...
@@ -23,13 +23,13 @@
# limitations under the License.
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import
math
import
math
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module):
...
@@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
@@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module):
...
@@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module):
q
,
k
=
q
.
float
(),
k
.
float
()
q
,
k
=
q
.
float
(),
k
.
float
()
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
q
.
to
(
orig_dtype
),
k
.
to
(
orig_dtype
)
q
,
k
=
q
.
to
(
orig_dtype
),
k
.
to
(
orig_dtype
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module):
...
@@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module):
...
@@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
*
\
hidden_states
=
residual
+
hidden_states
*
\
(
self
.
config
.
scale_depth
/
math
.
sqrt
(
self
.
config
.
num_hidden_layers
))
(
self
.
config
.
scale_depth
/
math
.
sqrt
(
self
.
config
.
num_hidden_layers
))
...
@@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module):
...
@@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module):
...
@@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/minicpm3.py
View file @
cdc1fa12
...
@@ -29,7 +29,7 @@ import torch
...
@@ -29,7 +29,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module):
...
@@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
q
,
_
=
self
.
q_a_proj
(
hidden_states
)
q
,
_
=
self
.
q_a_proj
(
hidden_states
)
q
=
self
.
q_a_layernorm
(
q
)
q
=
self
.
q_a_layernorm
(
q
)
...
@@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module):
...
@@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module):
v
,
[
0
,
self
.
qk_head_dim
-
self
.
v_head_dim
],
v
,
[
0
,
self
.
qk_head_dim
-
self
.
v_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
self
.
qk_head_dim
)
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
self
.
qk_head_dim
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
attn_output
.
view
(
attn_output
=
attn_output
.
view
(
-
1
,
self
.
num_local_heads
,
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)[...,
:
self
.
v_head_dim
].
reshape
(
self
.
qk_head_dim
)[...,
:
self
.
v_head_dim
].
reshape
(
...
...
vllm/model_executor/models/minicpmo.py
View file @
cdc1fa12
...
@@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
...
@@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
from
transformers.models.whisper.modeling_whisper
import
(
from
transformers.models.whisper.modeling_whisper
import
(
ACT2FN
,
WHISPER_ATTENTION_CLASSES
,
WhisperConfig
,
WhisperEncoder
)
ACT2FN
,
WHISPER_ATTENTION_CLASSES
,
WhisperConfig
,
WhisperEncoder
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
...
@@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6):
...
@@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6):
...
@@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6):
output
=
self
.
llm
.
model
(
output
=
self
.
llm
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
vlm_embeddings
,
inputs_embeds
=
vlm_embeddings
,
)
)
...
...
vllm/model_executor/models/minicpmv.py
View file @
cdc1fa12
...
@@ -37,7 +37,6 @@ from torch import nn
...
@@ -37,7 +37,6 @@ from torch import nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
BatchFeature
,
PretrainedConfig
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.resampler
import
(
BaseResampler
,
Resampler2
,
from
vllm.model_executor.layers.resampler
import
(
BaseResampler
,
Resampler2
,
...
@@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
output
=
self
.
llm
.
model
(
output
=
self
.
llm
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
vlm_embeddings
,
inputs_embeds
=
vlm_embeddings
,
)
)
...
...
vllm/model_executor/models/mixtral.py
View file @
cdc1fa12
...
@@ -22,13 +22,13 @@
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Mixtral model."""
"""Inference-only Mixtral model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -175,13 +175,11 @@ class MixtralAttention(nn.Module):
...
@@ -175,13 +175,11 @@ class MixtralAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
...
@@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -291,8 +285,6 @@ class MixtralModel(nn.Module):
...
@@ -291,8 +285,6 @@ class MixtralModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -306,11 +298,8 @@ class MixtralModel(nn.Module):
...
@@ -306,11 +298,8 @@ class MixtralModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
cdc1fa12
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Mixtral model."""
"""Inference-only Mixtral model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -30,7 +30,7 @@ import torch.nn.functional as F
...
@@ -30,7 +30,7 @@ import torch.nn.functional as F
from
torch
import
nn
from
torch
import
nn
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
...
@@ -229,13 +229,11 @@ class MixtralAttention(nn.Module):
...
@@ -229,13 +229,11 @@ class MixtralAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
...
@@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -333,8 +327,6 @@ class MixtralModel(nn.Module):
...
@@ -333,8 +327,6 @@ class MixtralModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -348,11 +340,8 @@ class MixtralModel(nn.Module):
...
@@ -348,11 +340,8 @@ class MixtralModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
...
@@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/mllama.py
View file @
cdc1fa12
...
@@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
...
@@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.selector
import
_Backend
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module):
...
@@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module):
prefix
:
str
=
""
):
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
model
_parallel_size
=
get_t
ensor_model_parallel_
world_size
()
tensor
_parallel_size
=
get_t
p_group
().
world_size
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
attention_heads
self
.
num_heads
=
config
.
attention_heads
self
.
head_dim
=
config
.
hidden_size
//
config
.
attention_heads
self
.
head_dim
=
config
.
hidden_size
//
config
.
attention_heads
self
.
num_local_heads
=
self
.
num_heads
//
model
_parallel_size
self
.
num_local_heads
=
self
.
num_heads
//
tensor
_parallel_size
self
.
q_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
q_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_local_heads
*
self
.
head_dim
...
@@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
model_parallel_size
=
get_tensor_model_parallel_world_size
()
self
.
pipeline_parallel_rank
=
get_pp_group
().
rank_in_group
self
.
tensor_parallel_size
=
get_tp_group
().
world_size
self
.
num_heads
=
self
.
config
.
num_attention_heads
self
.
num_heads
=
self
.
config
.
num_attention_heads
self
.
num_local_heads
=
self
.
num_heads
//
self
.
model
_parallel_size
self
.
num_local_heads
=
self
.
num_heads
//
self
.
tensor
_parallel_size
self
.
num_key_value_heads
=
self
.
config
.
num_key_value_heads
self
.
num_key_value_heads
=
self
.
config
.
num_key_value_heads
self
.
num_local_key_value_heads
=
\
self
.
num_local_key_value_heads
=
\
self
.
num_key_value_heads
//
self
.
model
_parallel_size
self
.
num_key_value_heads
//
self
.
tensor
_parallel_size
self
.
dropout
=
config
.
dropout
self
.
dropout
=
config
.
dropout
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
head_dim
=
config
.
hidden_size
//
self
.
num_heads
self
.
head_dim
=
config
.
hidden_size
//
self
.
num_heads
...
@@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module):
attention_mask
:
Optional
[
torch
.
Tensor
],
attention_mask
:
Optional
[
torch
.
Tensor
],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
cross_attention_states
:
Optional
[
torch
.
Tensor
],
cross_attention_states
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv_dec
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv_dec
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
_
,
_
=
qkv_dec
.
split
(
q
,
_
,
_
=
qkv_dec
.
split
(
...
@@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module):
q
=
self
.
q_norm
(
q
)
q
=
self
.
q_norm
(
q
)
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
output
=
self
.
_attention_with_mask
(
q
,
k
,
v
,
kv_cache
,
output
=
self
.
_attention_with_mask
(
q
,
k
,
v
,
attention_mask
,
attention_mask
,
kv_range_for_decode
)
kv_range_for_decode
,
attn_metadata
)
else
:
else
:
output
=
self
.
attn
(
output
=
self
.
attn
(
q
.
view
(
-
1
,
self
.
num_local_heads
*
self
.
head_dim
),
k
,
v
,
q
.
view
(
-
1
,
self
.
num_local_heads
*
self
.
head_dim
),
k
,
v
)
kv_cache
,
attn_metadata
)
out
,
_
=
self
.
o_proj
(
output
)
out
,
_
=
self
.
o_proj
(
output
)
return
out
return
out
...
@@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module):
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
kv_range_for_decode
:
List
[
Tuple
[
int
,
int
]],
kv_range_for_decode
:
List
[
Tuple
[
int
,
int
]],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
kv_cache
=
self
.
attn
.
kv_cache
[
self
.
pipeline_parallel_rank
]
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
# Skip writing kv-cache for the initial profiling run.
# Skip writing kv-cache for the initial profiling run.
if
len
(
kv_cache
.
shape
)
>
1
:
if
len
(
kv_cache
.
shape
)
>
1
:
i
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
i
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
...
@@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
...
@@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
cross_attention_mask
:
torch
.
Tensor
,
cross_attention_mask
:
torch
.
Tensor
,
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
torch
.
Tensor
,
full_text_row_masked_out_mask
:
torch
.
Tensor
,
kv_cache
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
@@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
...
@@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
attention_mask
=
cross_attention_mask
,
attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
kv_range_for_decode
=
kv_range_for_decode
,
cross_attention_states
=
cross_attention_states
,
cross_attention_states
=
cross_attention_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
full_text_row_masked_out_mask
*
hidden_states
hidden_states
=
full_text_row_masked_out_mask
*
hidden_states
hidden_states
=
residual
+
self
.
cross_attn_attn_gate
.
tanh
(
hidden_states
=
residual
+
self
.
cross_attn_attn_gate
.
tanh
(
...
@@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module):
...
@@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
skip_cross_attention
:
bool
,
skip_cross_attention
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
)
:
for
decoder_layer
in
self
.
layers
:
if
isinstance
(
decoder_layer
,
MllamaCrossAttentionDecoderLayer
):
if
isinstance
(
decoder_layer
,
MllamaCrossAttentionDecoderLayer
):
if
not
skip_cross_attention
:
if
not
skip_cross_attention
:
hidden_states
=
decoder_layer
(
hidden_states
=
decoder_layer
(
...
@@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module):
...
@@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode
=
kv_range_for_decode
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
full_text_row_masked_out_mask
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
)
)
elif
isinstance
(
decoder_layer
,
LlamaDecoderLayer
):
elif
isinstance
(
decoder_layer
,
LlamaDecoderLayer
):
hidden_states
,
residual
=
decoder_layer
(
hidden_states
,
residual
=
decoder_layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
residual
=
None
,
residual
=
None
,
)
)
hidden_states
=
hidden_states
+
residual
hidden_states
=
hidden_states
+
residual
...
@@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module):
...
@@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module):
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
skip_cross_attention
:
bool
,
skip_cross_attention
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
...
@@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module):
...
@@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module):
cross_attention_mask
=
cross_attention_mask
,
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
skip_cross_attention
=
skip_cross_attention
,
skip_cross_attention
=
skip_cross_attention
,
)
)
return
hidden_states
return
hidden_states
...
@@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
object
,
**
kwargs
:
object
,
)
->
Union
[
Tuple
,
CausalLMOutputWithPast
]:
)
->
Union
[
Tuple
,
CausalLMOutputWithPast
]:
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefill_tokens
>
0
and
\
if
attn_metadata
.
num_prefill_tokens
>
0
and
\
attn_metadata
.
num_decode_tokens
>
0
:
attn_metadata
.
num_decode_tokens
>
0
:
raise
ValueError
(
"Chunk prefill not supported"
)
raise
ValueError
(
"Chunk prefill not supported"
)
...
@@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
cross_attention_mask
=
cross_attention_mask
,
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
skip_cross_attention
=
skip_cross_attention
,
skip_cross_attention
=
skip_cross_attention
,
)
)
...
...
vllm/model_executor/models/molmo.py
View file @
cdc1fa12
...
@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
...
@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
from
transformers.image_utils
import
ImageInput
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
TextInput
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
...
@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
...
@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
q_norm
is
not
None
and
self
.
k_norm
is
not
None
:
if
self
.
q_norm
is
not
None
and
self
.
k_norm
is
not
None
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
...
@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Self Attention
# Self Attention
...
@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
...
@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
...
@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
...
@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Self Attention
# Self Attention
...
@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
...
@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
...
@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
...
@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
# Apply blocks one-by-one.
# Apply blocks one-by-one.
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self
,
self
,
input_ids
:
torch
.
LongTensor
,
input_ids
:
torch
.
LongTensor
,
positions
:
torch
.
LongTensor
,
positions
:
torch
.
LongTensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
hidden_states
=
self
.
model
(
input_ids
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment