Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdc1fa12
Unverified
Commit
cdc1fa12
authored
Feb 25, 2025
by
Harry Mellor
Committed by
GitHub
Feb 24, 2025
Browse files
Remove unused kwargs from model definitions (#13555)
parent
f61528d4
Changes
104
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
77 additions
and
287 deletions
+77
-287
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+7
-28
vllm/model_executor/models/internlm2_ve.py
vllm/model_executor/models/internlm2_ve.py
+2
-12
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+0
-5
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+8
-24
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+5
-24
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+7
-21
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+0
-5
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+0
-5
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+0
-5
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+0
-5
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+4
-12
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+7
-11
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+5
-19
vllm/model_executor/models/minicpm3.py
vllm/model_executor/models/minicpm3.py
+2
-4
vllm/model_executor/models/minicpmo.py
vllm/model_executor/models/minicpmo.py
+0
-5
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+0
-5
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+6
-20
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+6
-20
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+15
-35
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+3
-22
No files found.
vllm/model_executor/models/internlm2.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -175,13 +175,11 @@ class InternLM2Attention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
wqkv
(
hidden_states
)
q
,
k
,
v
=
self
.
split_qkv
(
qkv
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
wo
(
attn_output
)
return
output
...
...
@@ -227,8 +225,6 @@ class InternLMDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -241,8 +237,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states
=
self
.
attention
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -290,8 +284,6 @@ class InternLM2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -305,15 +297,8 @@ class InternLM2Model(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
...
...
@@ -363,13 +348,10 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
@@ -466,13 +448,10 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
logits
,
_
=
self
.
v_head
(
hidden_states
)
return
logits
...
...
vllm/model_executor/models/internlm2_ve.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -65,8 +64,6 @@ class InternLM2VEDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
visual_token_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -80,8 +77,6 @@ class InternLM2VEDecoderLayer(nn.Module):
hidden_states
=
self
.
attention
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -113,8 +108,6 @@ class InternLM2VEModel(InternLM2Model):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
visual_token_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -129,13 +122,10 @@ class InternLM2VEModel(InternLM2Model):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
visual_token_mask
=
visual_token_mask
,
)
...
...
vllm/model_executor/models/internvl.py
View file @
cdc1fa12
...
...
@@ -17,7 +17,6 @@ import torchvision.transforms as T
from
PIL
import
Image
from
transformers
import
BatchFeature
,
PretrainedConfig
,
TensorType
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
...
...
@@ -929,8 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -951,8 +948,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
forward_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
"intermediate_tensors"
:
intermediate_tensors
,
"inputs_embeds"
:
inputs_embeds
,
}
...
...
vllm/model_executor/models/jais.py
View file @
cdc1fa12
...
...
@@ -21,12 +21,12 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -123,12 +123,10 @@ class JAISAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
return
attn_output
...
...
@@ -200,16 +198,10 @@ class JAISBlock(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
)
# residual connection
hidden_states
=
attn_output
+
residual
...
...
@@ -266,8 +258,6 @@ class JAISModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
...
...
@@ -285,11 +275,8 @@ class JAISModel(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
for
layer
in
self
.
h
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
...
@@ -332,14 +319,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/jamba.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Jamba model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
...
@@ -138,7 +137,6 @@ class JambaMambaDecoderLayer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
...
...
@@ -150,8 +148,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mamba
(
hidden_states
,
attn_metadata
,
mamba_cache_params
)
hidden_states
=
self
.
mamba
(
hidden_states
,
mamba_cache_params
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
hidden_states
,
residual
)
...
...
@@ -223,13 +220,11 @@ class JambaAttentionDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
**
kwargs
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -237,8 +232,6 @@ class JambaAttentionDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
**
kwargs
,
):
...
...
@@ -252,8 +245,6 @@ class JambaAttentionDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attention
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
pre_ff_layernorm
(
...
...
@@ -320,8 +311,6 @@ class JambaModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -339,12 +328,9 @@ class JambaModel(nn.Module):
kv_cache_index
=
0
mamba_cache_index
=
0
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
kv_cache
=
None
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer_mamba_cache_params
=
None
if
isinstance
(
layer
,
JambaAttentionDecoderLayer
):
kv_cache
=
kv_caches
[
kv_cache_index
]
kv_cache_index
+=
1
if
isinstance
(
layer
,
JambaMambaDecoderLayer
):
current_state_layer
=
mamba_cache_index
...
...
@@ -355,8 +341,6 @@ class JambaModel(nn.Module):
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
mamba_cache_params
=
layer_mamba_cache_params
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -429,8 +413,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
...
...
@@ -443,8 +425,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
mamba_cache_params
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/llama.py
View file @
cdc1fa12
...
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
from
torch
import
nn
from
transformers
import
LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -197,13 +197,11 @@ class LlamaAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -268,8 +266,6 @@ class LlamaDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -280,9 +276,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
hidden_states
=
hidden_states
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
...
...
@@ -347,8 +341,6 @@ class LlamaModel(nn.Module):
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -363,11 +355,8 @@ class LlamaModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
...
...
@@ -535,13 +524,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
model_output
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
...
...
vllm/model_executor/models/llava.py
View file @
cdc1fa12
...
...
@@ -15,7 +15,6 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from
transformers.models.llava
import
LlavaProcessor
from
transformers.models.pixtral
import
PixtralProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
InputProcessingContext
from
vllm.model_executor.layers.activation
import
get_act_fn
...
...
@@ -658,8 +657,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -712,8 +709,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/llava_next.py
View file @
cdc1fa12
...
...
@@ -12,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -508,8 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -571,8 +568,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/llava_next_video.py
View file @
cdc1fa12
...
...
@@ -10,7 +10,6 @@ import torch.nn as nn
from
transformers
import
(
BatchFeature
,
LlavaNextVideoConfig
,
LlavaNextVideoProcessor
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
...
...
@@ -443,8 +442,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -468,8 +465,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/llava_onevision.py
View file @
cdc1fa12
...
...
@@ -13,7 +13,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
...
...
@@ -922,8 +921,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -955,8 +952,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/mamba.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
...
...
@@ -64,7 +63,6 @@ class MambaDecoderLayer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
**
kwargs
,
...
...
@@ -75,8 +73,7 @@ class MambaDecoderLayer(nn.Module):
else
:
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
mamba_cache_params
)
hidden_states
=
self
.
mixer
(
hidden_states
,
mamba_cache_params
)
return
hidden_states
,
residual
...
...
@@ -125,7 +122,6 @@ class MambaModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -146,7 +142,6 @@ class MambaModel(nn.Module):
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
))
...
...
@@ -208,8 +203,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
...
...
@@ -222,9 +215,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/mamba2.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
...
...
@@ -63,7 +64,6 @@ class Mamba2DecoderLayer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
],
...
...
@@ -75,8 +75,8 @@ class Mamba2DecoderLayer(nn.Module):
else
:
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
attn_metadata
,
mamba_cache_params
,
sequence_idx
)
hidden_states
=
self
.
mixer
(
hidden_states
,
mamba_cache_params
,
sequence_idx
)
return
hidden_states
,
residual
...
...
@@ -122,7 +122,6 @@ class Mamba2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -142,6 +141,7 @@ class Mamba2Model(nn.Module):
# proper continuous batching computation including
# chunked prefill
seq_idx
=
None
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefills
>
0
:
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
for
i
,
(
srt
,
end
)
in
enumerate
(
...
...
@@ -158,7 +158,6 @@ class Mamba2Model(nn.Module):
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
),
...
...
@@ -224,8 +223,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
...
...
@@ -238,9 +235,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
mamba_cache_params
=
self
.
mamba_cache
.
current_run_tensors
(
**
kwargs
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
mamba_cache_params
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/minicpm.py
View file @
cdc1fa12
...
...
@@ -23,13 +23,13 @@
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import
math
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -257,8 +257,6 @@ class MiniCPMAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
...
@@ -266,7 +264,7 @@ class MiniCPMAttention(nn.Module):
q
,
k
=
q
.
float
(),
k
.
float
()
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
q
.
to
(
orig_dtype
),
k
.
to
(
orig_dtype
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -331,8 +329,6 @@ class MiniCPMDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -341,8 +337,6 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
residual
+
hidden_states
*
\
(
self
.
config
.
scale_depth
/
math
.
sqrt
(
self
.
config
.
num_hidden_layers
))
...
...
@@ -409,8 +403,6 @@ class MiniCPMModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -424,13 +416,10 @@ class MiniCPMModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -579,13 +568,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/minicpm3.py
View file @
cdc1fa12
...
...
@@ -29,7 +29,7 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -129,8 +129,6 @@ class MiniCPM3Attention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
q
,
_
=
self
.
q_a_proj
(
hidden_states
)
q
=
self
.
q_a_layernorm
(
q
)
...
...
@@ -170,7 +168,7 @@ class MiniCPM3Attention(nn.Module):
v
,
[
0
,
self
.
qk_head_dim
-
self
.
v_head_dim
],
value
=
0
).
view
(
-
1
,
self
.
num_local_heads
*
self
.
qk_head_dim
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
attn_output
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)[...,
:
self
.
v_head_dim
].
reshape
(
...
...
vllm/model_executor/models/minicpmo.py
View file @
cdc1fa12
...
...
@@ -33,7 +33,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
from
transformers.models.whisper.modeling_whisper
import
(
ACT2FN
,
WHISPER_ATTENTION_CLASSES
,
WhisperConfig
,
WhisperEncoder
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
...
...
@@ -792,8 +791,6 @@ class MiniCPMO(MiniCPMV2_6):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
Any
,
)
->
torch
.
Tensor
:
...
...
@@ -818,8 +815,6 @@ class MiniCPMO(MiniCPMV2_6):
output
=
self
.
llm
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
vlm_embeddings
,
)
...
...
vllm/model_executor/models/minicpmv.py
View file @
cdc1fa12
...
...
@@ -37,7 +37,6 @@ from torch import nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
typing_extensions
import
TypeVar
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.resampler
import
(
BaseResampler
,
Resampler2
,
...
...
@@ -1030,8 +1029,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
**
kwargs
:
Any
,
)
->
torch
.
Tensor
:
...
...
@@ -1051,8 +1048,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
output
=
self
.
llm
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
vlm_embeddings
,
)
...
...
vllm/model_executor/models/mixtral.py
View file @
cdc1fa12
...
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -175,13 +175,11 @@ class MixtralAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -224,8 +222,6 @@ class MixtralDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -238,8 +234,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -291,8 +285,6 @@ class MixtralModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -306,11 +298,8 @@ class MixtralModel(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
...
...
@@ -377,13 +366,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
cdc1fa12
...
...
@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
torch
...
...
@@ -30,7 +30,7 @@ import torch.nn.functional as F
from
torch
import
nn
from
transformers
import
MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
...
...
@@ -229,13 +229,11 @@ class MixtralAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -274,8 +272,6 @@ class MixtralDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -288,8 +284,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -333,8 +327,6 @@ class MixtralModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -348,11 +340,8 @@ class MixtralModel(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
...
...
@@ -390,13 +379,10 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/mllama.py
View file @
cdc1fa12
...
...
@@ -38,7 +38,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -416,11 +417,11 @@ class MllamaVisionSdpaAttention(nn.Module):
prefix
:
str
=
""
):
super
().
__init__
()
model
_parallel_size
=
get_t
ensor_model_parallel_
world_size
()
tensor
_parallel_size
=
get_t
p_group
().
world_size
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
attention_heads
self
.
head_dim
=
config
.
hidden_size
//
config
.
attention_heads
self
.
num_local_heads
=
self
.
num_heads
//
model
_parallel_size
self
.
num_local_heads
=
self
.
num_heads
//
tensor
_parallel_size
self
.
q_size
=
self
.
num_local_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_local_heads
*
self
.
head_dim
...
...
@@ -771,12 +772,13 @@ class MllamaTextCrossAttention(nn.Module):
):
super
().
__init__
()
self
.
config
=
config
self
.
model_parallel_size
=
get_tensor_model_parallel_world_size
()
self
.
pipeline_parallel_rank
=
get_pp_group
().
rank_in_group
self
.
tensor_parallel_size
=
get_tp_group
().
world_size
self
.
num_heads
=
self
.
config
.
num_attention_heads
self
.
num_local_heads
=
self
.
num_heads
//
self
.
model
_parallel_size
self
.
num_local_heads
=
self
.
num_heads
//
self
.
tensor
_parallel_size
self
.
num_key_value_heads
=
self
.
config
.
num_key_value_heads
self
.
num_local_key_value_heads
=
\
self
.
num_key_value_heads
//
self
.
model
_parallel_size
self
.
num_key_value_heads
//
self
.
tensor
_parallel_size
self
.
dropout
=
config
.
dropout
self
.
hidden_size
=
config
.
hidden_size
self
.
head_dim
=
config
.
hidden_size
//
self
.
num_heads
...
...
@@ -824,8 +826,6 @@ class MllamaTextCrossAttention(nn.Module):
attention_mask
:
Optional
[
torch
.
Tensor
],
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
cross_attention_states
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv_dec
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
_
,
_
=
qkv_dec
.
split
(
...
...
@@ -846,14 +846,11 @@ class MllamaTextCrossAttention(nn.Module):
q
=
self
.
q_norm
(
q
)
if
attention_mask
is
not
None
:
output
=
self
.
_attention_with_mask
(
q
,
k
,
v
,
kv_cache
,
attention_mask
,
kv_range_for_decode
,
attn_metadata
)
output
=
self
.
_attention_with_mask
(
q
,
k
,
v
,
attention_mask
,
kv_range_for_decode
)
else
:
output
=
self
.
attn
(
q
.
view
(
-
1
,
self
.
num_local_heads
*
self
.
head_dim
),
k
,
v
,
kv_cache
,
attn_metadata
)
q
.
view
(
-
1
,
self
.
num_local_heads
*
self
.
head_dim
),
k
,
v
)
out
,
_
=
self
.
o_proj
(
output
)
return
out
...
...
@@ -862,11 +859,11 @@ class MllamaTextCrossAttention(nn.Module):
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
kv_range_for_decode
:
List
[
Tuple
[
int
,
int
]],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
kv_cache
=
self
.
attn
.
kv_cache
[
self
.
pipeline_parallel_rank
]
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
# Skip writing kv-cache for the initial profiling run.
if
len
(
kv_cache
.
shape
)
>
1
:
i
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
)
...
...
@@ -978,8 +975,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
cross_attention_mask
:
torch
.
Tensor
,
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
torch
.
Tensor
,
kv_cache
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
...
@@ -989,8 +984,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
cross_attention_states
=
cross_attention_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
full_text_row_masked_out_mask
*
hidden_states
hidden_states
=
residual
+
self
.
cross_attn_attn_gate
.
tanh
(
...
...
@@ -1054,14 +1047,12 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
skip_cross_attention
:
bool
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
inputs_embeds
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
)
:
for
decoder_layer
in
self
.
layers
:
if
isinstance
(
decoder_layer
,
MllamaCrossAttentionDecoderLayer
):
if
not
skip_cross_attention
:
hidden_states
=
decoder_layer
(
...
...
@@ -1071,15 +1062,11 @@ class MllamaTextModel(nn.Module):
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
)
elif
isinstance
(
decoder_layer
,
LlamaDecoderLayer
):
hidden_states
,
residual
=
decoder_layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_caches
[
idx
],
attn_metadata
=
attn_metadata
,
residual
=
None
,
)
hidden_states
=
hidden_states
+
residual
...
...
@@ -1124,8 +1111,6 @@ class MllamaForCausalLM(nn.Module):
kv_range_for_decode
:
Optional
[
List
[
Tuple
[
int
,
int
]]],
full_text_row_masked_out_mask
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
skip_cross_attention
:
bool
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
...
...
@@ -1135,8 +1120,6 @@ class MllamaForCausalLM(nn.Module):
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
skip_cross_attention
=
skip_cross_attention
,
)
return
hidden_states
...
...
@@ -1353,10 +1336,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
**
kwargs
:
object
,
)
->
Union
[
Tuple
,
CausalLMOutputWithPast
]:
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefill_tokens
>
0
and
\
attn_metadata
.
num_decode_tokens
>
0
:
raise
ValueError
(
"Chunk prefill not supported"
)
...
...
@@ -1410,8 +1392,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
cross_attention_mask
=
cross_attention_mask
,
kv_range_for_decode
=
kv_range_for_decode
,
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
skip_cross_attention
=
skip_cross_attention
,
)
...
...
vllm/model_executor/models/molmo.py
View file @
cdc1fa12
...
...
@@ -16,7 +16,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
from
transformers.image_utils
import
ImageInput
from
transformers.tokenization_utils_base
import
TextInput
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
...
...
@@ -460,15 +460,13 @@ class MolmoAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
q_norm
is
not
None
and
self
.
k_norm
is
not
None
:
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -580,8 +578,6 @@ class MolmoDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Self Attention
...
...
@@ -594,8 +590,6 @@ class MolmoDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
...
...
@@ -610,8 +604,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Self Attention
...
...
@@ -619,8 +611,6 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
...
@@ -841,8 +831,6 @@ class MolmoModel(nn.Module, SupportsQuant):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -858,13 +846,10 @@ class MolmoModel(nn.Module, SupportsQuant):
residual
=
intermediate_tensors
[
"residual"
]
# Apply blocks one-by-one.
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -1643,8 +1628,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
torch
.
LongTensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -1663,8 +1646,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment