Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdc1fa12
Unverified
Commit
cdc1fa12
authored
Feb 25, 2025
by
Harry Mellor
Committed by
GitHub
Feb 24, 2025
Browse files
Remove unused kwargs from model definitions (#13555)
parent
f61528d4
Changes
104
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
85 additions
and
340 deletions
+85
-340
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+7
-24
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+5
-23
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+6
-22
vllm/model_executor/models/olmo2.py
vllm/model_executor/models/olmo2.py
+6
-22
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+5
-19
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+7
-25
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+6
-23
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+1
-6
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+5
-22
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+6
-23
vllm/model_executor/models/phi3_small.py
vllm/model_executor/models/phi3_small.py
+5
-23
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+0
-5
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+5
-19
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+0
-5
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+1
-4
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+6
-20
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+6
-23
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+0
-5
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+2
-7
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+6
-20
No files found.
vllm/model_executor/models/mpt.py
View file @
cdc1fa12
...
...
@@ -2,12 +2,12 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -125,8 +125,6 @@ class MPTAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
del
position_ids
# unused.
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
...
...
@@ -136,7 +134,7 @@ class MPTAttention(nn.Module):
if
self
.
qk_ln
:
q
=
self
.
q_ln
(
q
)
k
=
self
.
k_ln
(
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
...
...
@@ -196,15 +194,11 @@ class MPTBlock(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
x
=
self
.
norm_1
(
hidden_states
)
x
=
self
.
attn
(
position_ids
=
position_ids
,
hidden_states
=
x
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
hidden_states
+
x
x
=
self
.
norm_2
(
hidden_states
)
...
...
@@ -253,8 +247,6 @@ class MPTModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -267,14 +259,8 @@ class MPTModel(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
block
=
self
.
blocks
[
i
]
hidden_states
=
block
(
position_ids
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
for
block
in
self
.
blocks
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
block
(
position_ids
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
norm_f
(
hidden_states
)
...
...
@@ -306,14 +292,11 @@ class MPTForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/nemotron.py
View file @
cdc1fa12
...
...
@@ -27,7 +27,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -204,13 +204,11 @@ class NemotronAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -269,8 +267,6 @@ class NemotronDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -283,8 +279,6 @@ class NemotronDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -343,8 +337,6 @@ class NemotronModel(nn.Module):
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -359,15 +351,8 @@ class NemotronModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
...
...
@@ -444,13 +429,10 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
model_output
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
...
...
vllm/model_executor/models/olmo.py
View file @
cdc1fa12
...
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
OlmoConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -119,15 +119,13 @@ class OlmoAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
self
.
clip_qkv
is
not
None
:
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -212,14 +210,11 @@ class OlmoDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Attention block.
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
,
kv_cache
,
attn_metadata
)
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
)
hidden_states
=
hidden_states
+
residual
# MLP block.
...
...
@@ -263,8 +258,6 @@ class OlmoModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -281,14 +274,9 @@ class OlmoModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
# Apply blocks one-by-one.
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
)
:
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]
:
# shape: (batch_size, seq_len, d_model)
hidden_states
=
self
.
layers
[
i
](
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
hidden_states
=
layer
(
positions
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
...
@@ -332,16 +320,12 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
...
...
vllm/model_executor/models/olmo2.py
View file @
cdc1fa12
...
...
@@ -24,12 +24,12 @@
"""Inference-only OLMo2 model compatible with HuggingFace weights."""
from
functools
import
partial
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_gather
...
...
@@ -153,14 +153,12 @@ class Olmo2Attention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -239,13 +237,10 @@ class Olmo2DecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# Attention block.
residual
=
hidden_states
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
,
kv_cache
,
attn_metadata
)
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
)
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
hidden_states
+
residual
...
...
@@ -287,8 +282,6 @@ class Olmo2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
"""
...
...
@@ -307,14 +300,9 @@ class Olmo2Model(nn.Module):
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
# Apply blocks one-by-one.
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
)
:
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]
:
# shape: (batch_size, seq_len, d_model)
hidden_states
=
self
.
layers
[
i
](
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
hidden_states
=
layer
(
positions
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
...
@@ -357,15 +345,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
)
return
hidden_states
...
...
vllm/model_executor/models/olmoe.py
View file @
cdc1fa12
...
...
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -168,14 +168,12 @@ class OlmoeAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
q_norm
(
q
.
contiguous
()),
self
.
k_norm
(
k
.
contiguous
())
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -222,8 +220,6 @@ class OlmoeDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -237,8 +233,6 @@ class OlmoeDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -283,8 +277,6 @@ class OlmoeModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -299,13 +291,10 @@ class OlmoeModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
...
...
@@ -347,13 +336,10 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/opt.py
View file @
cdc1fa12
...
...
@@ -18,13 +18,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
OPTConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -107,12 +107,10 @@ class OPTAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
...
...
@@ -164,17 +162,13 @@ class OPTDecoderLayer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# Self Attention
residual
=
hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if
self
.
do_layer_norm_before
:
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
)
hidden_states
=
residual
+
hidden_states
# 350m applies layer norm AFTER attention
if
not
self
.
do_layer_norm_before
:
...
...
@@ -261,8 +255,6 @@ class OPTDecoder(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -277,11 +269,8 @@ class OPTDecoder(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
...
@@ -317,15 +306,11 @@ class OPTModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
@@ -362,13 +347,10 @@ class OPTForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/orion.py
View file @
cdc1fa12
...
...
@@ -5,13 +5,13 @@
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -136,13 +136,11 @@ class OrionAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -189,8 +187,6 @@ class OrionDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
residual
=
hidden_states
...
...
@@ -198,8 +194,6 @@ class OrionDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
residual
+
hidden_states
...
...
@@ -247,8 +241,6 @@ class OrionModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -260,14 +252,8 @@ class OrionModel(nn.Module):
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
positions
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
...
...
@@ -303,13 +289,10 @@ class OrionForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/paligemma.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
from
torch
import
nn
from
transformers
import
PaliGemmaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
...
...
@@ -288,8 +287,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
)
->
Union
[
SamplerOutput
,
IntermediateTensors
]:
...
...
@@ -306,8 +303,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/persimmon.py
View file @
cdc1fa12
...
...
@@ -21,13 +21,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PersimmonConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -142,8 +142,6 @@ class PersimmonAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# [seq_length, 3 x hidden_size]
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
...
...
@@ -161,7 +159,7 @@ class PersimmonAttention(nn.Module):
k
=
self
.
_merge_heads
(
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
...
...
@@ -189,8 +187,6 @@ class PersimmonDecoderLayer(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
...
...
@@ -200,8 +196,6 @@ class PersimmonDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
residual
+
hidden_states
...
...
@@ -248,8 +242,6 @@ class PersimmonModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -261,13 +253,8 @@ class PersimmonModel(nn.Module):
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
hidden_states
=
self
.
layers
[
i
](
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
positions
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
...
...
@@ -298,16 +285,12 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
):
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
...
...
vllm/model_executor/models/phi.py
View file @
cdc1fa12
...
...
@@ -36,13 +36,13 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PhiConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -126,13 +126,11 @@ class PhiAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
...
...
@@ -186,16 +184,12 @@ class PhiLayer(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
attn_outputs
=
self
.
self_attn
(
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
feed_forward_hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
attn_outputs
+
feed_forward_hidden_states
+
residual
...
...
@@ -234,8 +228,6 @@ class PhiModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -247,14 +239,8 @@ class PhiModel(nn.Module):
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
positions
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
...
@@ -304,13 +290,10 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/phi3_small.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
...
...
@@ -231,8 +231,6 @@ class Phi3SmallSelfAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
...
...
@@ -248,7 +246,7 @@ class Phi3SmallSelfAttention(nn.Module):
v
=
v
.
reshape
(
-
1
,
self
.
head_dim
*
self
.
num_kv_heads_per_partion
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
=
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
...
...
@@ -282,8 +280,6 @@ class Phi3SmallDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
...
@@ -291,8 +287,6 @@ class Phi3SmallDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
residual
+
hidden_states
...
...
@@ -338,8 +332,6 @@ class Phi3SmallModel(nn.Module):
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
Optional
[
torch
.
LongTensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -354,14 +346,8 @@ class Phi3SmallModel(nn.Module):
else
:
assert
intermediate_tensors
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
layer
(
positions
,
hidden_states
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
...
...
@@ -438,16 +424,12 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
Optional
[
torch
.
LongTensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
output_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
...
...
vllm/model_executor/models/phi3v.py
View file @
cdc1fa12
...
...
@@ -23,7 +23,6 @@ import torch.nn as nn
from
transformers
import
(
BatchFeature
,
CLIPVisionConfig
,
PretrainedConfig
,
ProcessorMixin
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -672,8 +671,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
):
...
...
@@ -691,8 +688,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/phimoe.py
View file @
cdc1fa12
...
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only PhiMoE model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -357,13 +357,11 @@ class PhiMoEAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -410,8 +408,6 @@ class PhiMoEDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
residual
=
hidden_states
...
...
@@ -422,8 +418,6 @@ class PhiMoEDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
hidden_states
=
hidden_states
+
residual
...
...
@@ -478,8 +472,6 @@ class PhiMoEModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -494,13 +486,10 @@ class PhiMoEModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
...
...
@@ -571,13 +560,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/pixtral.py
View file @
cdc1fa12
...
...
@@ -16,7 +16,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
from
transformers.models.pixtral.modeling_pixtral
import
(
PixtralRotaryEmbedding
,
apply_rotary_pos_emb
,
position_ids_in_meshgrid
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
...
...
@@ -270,8 +269,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -291,8 +288,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
cdc1fa12
...
...
@@ -15,13 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
from
transformers
import
BatchFeature
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
...
...
@@ -181,8 +180,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
vllm/model_executor/models/qwen.py
View file @
cdc1fa12
...
...
@@ -6,13 +6,13 @@
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -124,13 +124,11 @@ class QWenAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
c_proj
(
attn_output
)
return
output
...
...
@@ -168,8 +166,6 @@ class QWenBlock(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -181,8 +177,6 @@ class QWenBlock(nn.Module):
hidden_states
=
self
.
attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -225,8 +219,6 @@ class QWenModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -241,13 +233,10 @@ class QWenModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
for
layer
in
self
.
h
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -373,12 +362,9 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
vllm/model_executor/models/qwen2.py
View file @
cdc1fa12
...
...
@@ -23,13 +23,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
Qwen2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
...
@@ -170,13 +170,11 @@ class Qwen2Attention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -233,8 +231,6 @@ class Qwen2DecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -247,8 +243,6 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -328,8 +322,6 @@ class Qwen2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -343,13 +335,10 @@ class Qwen2Model(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
...
...
@@ -468,13 +457,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
@@ -553,12 +539,9 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
return
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
)
def
pooler
(
self
,
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
cdc1fa12
...
...
@@ -37,7 +37,6 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from
transformers.models.qwen2_5_vl.configuration_qwen2_5_vl
import
(
Qwen2_5_VLConfig
,
Qwen2_5_VLVisionConfig
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
utils
as
dist_utils
...
...
@@ -992,8 +991,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -1047,8 +1044,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
cdc1fa12
...
...
@@ -22,8 +22,8 @@
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from
functools
import
cached_property
from
typing
import
(
Any
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
from
typing
import
(
Any
,
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
...
...
@@ -33,7 +33,6 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig,
Qwen2AudioProcessor
)
from
transformers.models.whisper
import
WhisperFeatureExtractor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -380,8 +379,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
...
...
@@ -400,8 +397,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
cdc1fa12
...
...
@@ -23,14 +23,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
...
...
@@ -232,13 +232,11 @@ class Qwen2MoeAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -296,8 +294,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -310,8 +306,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
...
...
@@ -358,8 +352,6 @@ class Qwen2MoeModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
...
@@ -373,11 +365,8 @@ class Qwen2MoeModel(nn.Module):
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
...
...
@@ -416,13 +405,10 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment