Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdc1fa12
"vllm/vscode:/vscode.git/clone" did not exist on "37d2b35d9d8a49600ba6141689e2c3af7b32323f"
Unverified
Commit
cdc1fa12
authored
Feb 25, 2025
by
Harry Mellor
Committed by
GitHub
Feb 24, 2025
Browse files
Remove unused kwargs from model definitions (#13555)
parent
f61528d4
Changes
104
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
85 additions
and
340 deletions
+85
-340
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+7
-24
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+5
-23
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+6
-22
vllm/model_executor/models/olmo2.py
vllm/model_executor/models/olmo2.py
+6
-22
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+5
-19
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+7
-25
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+6
-23
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+1
-6
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+5
-22
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+6
-23
vllm/model_executor/models/phi3_small.py
vllm/model_executor/models/phi3_small.py
+5
-23
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+0
-5
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+5
-19
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+0
-5
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+1
-4
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+6
-20
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+6
-23
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+0
-5
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+2
-7
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+6
-20
No files found.
vllm/model_executor/models/mpt.py
View file @
cdc1fa12
...
@@ -2,12 +2,12 @@
...
@@ -2,12 +2,12 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
@@ -125,8 +125,6 @@ class MPTAttention(nn.Module):
...
@@ -125,8 +125,6 @@ class MPTAttention(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
del
position_ids
# unused.
del
position_ids
# unused.
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
...
@@ -136,7 +134,7 @@ class MPTAttention(nn.Module):
...
@@ -136,7 +134,7 @@ class MPTAttention(nn.Module):
if
self
.
qk_ln
:
if
self
.
qk_ln
:
q
=
self
.
q_ln
(
q
)
q
=
self
.
q_ln
(
q
)
k
=
self
.
k_ln
(
k
)
k
=
self
.
k_ln
(
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -196,15 +194,11 @@ class MPTBlock(nn.Module):
...
@@ -196,15 +194,11 @@ class MPTBlock(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
x
=
self
.
norm_1
(
hidden_states
)
x
=
self
.
norm_1
(
hidden_states
)
x
=
self
.
attn
(
x
=
self
.
attn
(
position_ids
=
position_ids
,
position_ids
=
position_ids
,
hidden_states
=
x
,
hidden_states
=
x
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
hidden_states
+
x
hidden_states
=
hidden_states
+
x
x
=
self
.
norm_2
(
hidden_states
)
x
=
self
.
norm_2
(
hidden_states
)
...
@@ -253,8 +247,6 @@ class MPTModel(nn.Module):
...
@@ -253,8 +247,6 @@ class MPTModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -267,14 +259,8 @@ class MPTModel(nn.Module):
...
@@ -267,14 +259,8 @@ class MPTModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
block
in
self
.
blocks
[
self
.
start_layer
:
self
.
end_layer
]:
block
=
self
.
blocks
[
i
]
hidden_states
=
block
(
position_ids
,
hidden_states
)
hidden_states
=
block
(
position_ids
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
norm_f
(
hidden_states
)
hidden_states
=
self
.
norm_f
(
hidden_states
)
...
@@ -306,14 +292,11 @@ class MPTForCausalLM(nn.Module, SupportsPP):
...
@@ -306,14 +292,11 @@ class MPTForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/nemotron.py
View file @
cdc1fa12
...
@@ -27,7 +27,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
...
@@ -27,7 +27,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -204,13 +204,11 @@ class NemotronAttention(nn.Module):
...
@@ -204,13 +204,11 @@ class NemotronAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -269,8 +267,6 @@ class NemotronDecoderLayer(nn.Module):
...
@@ -269,8 +267,6 @@ class NemotronDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -283,8 +279,6 @@ class NemotronDecoderLayer(nn.Module):
...
@@ -283,8 +279,6 @@ class NemotronDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -343,8 +337,6 @@ class NemotronModel(nn.Module):
...
@@ -343,8 +337,6 @@ class NemotronModel(nn.Module):
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -359,15 +351,8 @@ class NemotronModel(nn.Module):
...
@@ -359,15 +351,8 @@ class NemotronModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
@@ -444,13 +429,10 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -444,13 +429,10 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
model_output
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
model_output
return
model_output
...
...
vllm/model_executor/models/olmo.py
View file @
cdc1fa12
...
@@ -22,13 +22,13 @@
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
"""Inference-only OLMo model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
OlmoConfig
from
transformers
import
OlmoConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -119,15 +119,13 @@ class OlmoAttention(nn.Module):
...
@@ -119,15 +119,13 @@ class OlmoAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
self
.
clip_qkv
is
not
None
:
if
self
.
clip_qkv
is
not
None
:
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -212,14 +210,11 @@ class OlmoDecoderLayer(nn.Module):
...
@@ -212,14 +210,11 @@ class OlmoDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Attention block.
# Attention block.
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
,
kv_cache
,
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
)
attn_metadata
)
hidden_states
=
hidden_states
+
residual
hidden_states
=
hidden_states
+
residual
# MLP block.
# MLP block.
...
@@ -263,8 +258,6 @@ class OlmoModel(nn.Module):
...
@@ -263,8 +258,6 @@ class OlmoModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -281,14 +274,9 @@ class OlmoModel(nn.Module):
...
@@ -281,14 +274,9 @@ class OlmoModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
# Apply blocks one-by-one.
# Apply blocks one-by-one.
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
)
:
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]
:
# shape: (batch_size, seq_len, d_model)
# shape: (batch_size, seq_len, d_model)
hidden_states
=
self
.
layers
[
i
](
hidden_states
=
layer
(
positions
,
hidden_states
)
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
@@ -332,16 +320,12 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
...
@@ -332,16 +320,12 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
...
...
vllm/model_executor/models/olmo2.py
View file @
cdc1fa12
...
@@ -24,12 +24,12 @@
...
@@ -24,12 +24,12 @@
"""Inference-only OLMo2 model compatible with HuggingFace weights."""
"""Inference-only OLMo2 model compatible with HuggingFace weights."""
from
functools
import
partial
from
functools
import
partial
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_gather
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_gather
...
@@ -153,14 +153,12 @@ class Olmo2Attention(nn.Module):
...
@@ -153,14 +153,12 @@ class Olmo2Attention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -239,13 +237,10 @@ class Olmo2DecoderLayer(nn.Module):
...
@@ -239,13 +237,10 @@ class Olmo2DecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Attention block.
# Attention block.
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
,
kv_cache
,
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
)
attn_metadata
)
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
hidden_states
+
residual
hidden_states
=
hidden_states
+
residual
...
@@ -287,8 +282,6 @@ class Olmo2Model(nn.Module):
...
@@ -287,8 +282,6 @@ class Olmo2Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
"""
"""
...
@@ -307,14 +300,9 @@ class Olmo2Model(nn.Module):
...
@@ -307,14 +300,9 @@ class Olmo2Model(nn.Module):
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
# Apply blocks one-by-one.
# Apply blocks one-by-one.
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
)
:
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]
:
# shape: (batch_size, seq_len, d_model)
# shape: (batch_size, seq_len, d_model)
hidden_states
=
self
.
layers
[
i
](
hidden_states
=
layer
(
positions
,
hidden_states
)
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
@@ -357,15 +345,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
...
@@ -357,15 +345,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
)
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/olmoe.py
View file @
cdc1fa12
...
@@ -12,13 +12,13 @@
...
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -168,14 +168,12 @@ class OlmoeAttention(nn.Module):
...
@@ -168,14 +168,12 @@ class OlmoeAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
q_norm
(
q
.
contiguous
()),
self
.
k_norm
(
k
.
contiguous
())
q
,
k
=
self
.
q_norm
(
q
.
contiguous
()),
self
.
k_norm
(
k
.
contiguous
())
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -222,8 +220,6 @@ class OlmoeDecoderLayer(nn.Module):
...
@@ -222,8 +220,6 @@ class OlmoeDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
...
@@ -237,8 +233,6 @@ class OlmoeDecoderLayer(nn.Module):
...
@@ -237,8 +233,6 @@ class OlmoeDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -283,8 +277,6 @@ class OlmoeModel(nn.Module):
...
@@ -283,8 +277,6 @@ class OlmoeModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -299,13 +291,10 @@ class OlmoeModel(nn.Module):
...
@@ -299,13 +291,10 @@ class OlmoeModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
...
@@ -347,13 +336,10 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
...
@@ -347,13 +336,10 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/opt.py
View file @
cdc1fa12
...
@@ -18,13 +18,13 @@
...
@@ -18,13 +18,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
"""Inference-only OPT model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
OPTConfig
from
transformers
import
OPTConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -107,12 +107,10 @@ class OPTAttention(nn.Module):
...
@@ -107,12 +107,10 @@ class OPTAttention(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -164,17 +162,13 @@ class OPTDecoderLayer(nn.Module):
...
@@ -164,17 +162,13 @@ class OPTDecoderLayer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
residual
=
hidden_states
residual
=
hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if
self
.
do_layer_norm_before
:
if
self
.
do_layer_norm_before
:
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
)
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
# 350m applies layer norm AFTER attention
# 350m applies layer norm AFTER attention
if
not
self
.
do_layer_norm_before
:
if
not
self
.
do_layer_norm_before
:
...
@@ -261,8 +255,6 @@ class OPTDecoder(nn.Module):
...
@@ -261,8 +255,6 @@ class OPTDecoder(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -277,11 +269,8 @@ class OPTDecoder(nn.Module):
...
@@ -277,11 +269,8 @@ class OPTDecoder(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
hidden_states
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
@@ -317,15 +306,11 @@ class OPTModel(nn.Module):
...
@@ -317,15 +306,11 @@ class OPTModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
return
self
.
decoder
(
input_ids
,
return
self
.
decoder
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
@@ -362,13 +347,10 @@ class OPTForCausalLM(nn.Module, SupportsPP):
...
@@ -362,13 +347,10 @@ class OPTForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/orion.py
View file @
cdc1fa12
...
@@ -5,13 +5,13 @@
...
@@ -5,13 +5,13 @@
# Copyright (c) OrionStar Inc.
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -136,13 +136,11 @@ class OrionAttention(nn.Module):
...
@@ -136,13 +136,11 @@ class OrionAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -189,8 +187,6 @@ class OrionDecoderLayer(nn.Module):
...
@@ -189,8 +187,6 @@ class OrionDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
residual
=
hidden_states
residual
=
hidden_states
...
@@ -198,8 +194,6 @@ class OrionDecoderLayer(nn.Module):
...
@@ -198,8 +194,6 @@ class OrionDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -247,8 +241,6 @@ class OrionModel(nn.Module):
...
@@ -247,8 +241,6 @@ class OrionModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -260,14 +252,8 @@ class OrionModel(nn.Module):
...
@@ -260,14 +252,8 @@ class OrionModel(nn.Module):
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
)
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -303,13 +289,10 @@ class OrionForCausalLM(nn.Module, SupportsPP):
...
@@ -303,13 +289,10 @@ class OrionForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/paligemma.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PaliGemmaConfig
from
transformers
import
PaliGemmaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
InputContext
,
token_inputs
)
...
@@ -288,8 +287,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -288,8 +287,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
)
->
Union
[
SamplerOutput
,
IntermediateTensors
]:
**
kwargs
:
object
)
->
Union
[
SamplerOutput
,
IntermediateTensors
]:
...
@@ -306,8 +303,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -306,8 +303,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/persimmon.py
View file @
cdc1fa12
...
@@ -21,13 +21,13 @@
...
@@ -21,13 +21,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
"""Inference-only persimmon model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PersimmonConfig
from
transformers
import
PersimmonConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -142,8 +142,6 @@ class PersimmonAttention(nn.Module):
...
@@ -142,8 +142,6 @@ class PersimmonAttention(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# [seq_length, 3 x hidden_size]
# [seq_length, 3 x hidden_size]
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
...
@@ -161,7 +159,7 @@ class PersimmonAttention(nn.Module):
...
@@ -161,7 +159,7 @@ class PersimmonAttention(nn.Module):
k
=
self
.
_merge_heads
(
k
)
k
=
self
.
_merge_heads
(
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
attn_output
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
return
output
...
@@ -189,8 +187,6 @@ class PersimmonDecoderLayer(nn.Module):
...
@@ -189,8 +187,6 @@ class PersimmonDecoderLayer(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
...
@@ -200,8 +196,6 @@ class PersimmonDecoderLayer(nn.Module):
...
@@ -200,8 +196,6 @@ class PersimmonDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
position_ids
=
position_ids
,
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -248,8 +242,6 @@ class PersimmonModel(nn.Module):
...
@@ -248,8 +242,6 @@ class PersimmonModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -261,13 +253,8 @@ class PersimmonModel(nn.Module):
...
@@ -261,13 +253,8 @@ class PersimmonModel(nn.Module):
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
=
self
.
layers
[
i
](
hidden_states
=
layer
(
positions
,
hidden_states
)
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
...
@@ -298,16 +285,12 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
...
@@ -298,16 +285,12 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
...
...
vllm/model_executor/models/phi.py
View file @
cdc1fa12
...
@@ -36,13 +36,13 @@
...
@@ -36,13 +36,13 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PhiConfig
from
transformers
import
PhiConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -126,13 +126,11 @@ class PhiAttention(nn.Module):
...
@@ -126,13 +126,11 @@ class PhiAttention(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
attn_output
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
return
output
...
@@ -186,16 +184,12 @@ class PhiLayer(nn.Module):
...
@@ -186,16 +184,12 @@ class PhiLayer(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
attn_outputs
=
self
.
self_attn
(
attn_outputs
=
self
.
self_attn
(
position_ids
=
position_ids
,
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
feed_forward_hidden_states
=
self
.
mlp
(
hidden_states
)
feed_forward_hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
attn_outputs
+
feed_forward_hidden_states
+
residual
hidden_states
=
attn_outputs
+
feed_forward_hidden_states
+
residual
...
@@ -234,8 +228,6 @@ class PhiModel(nn.Module):
...
@@ -234,8 +228,6 @@ class PhiModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -247,14 +239,8 @@ class PhiModel(nn.Module):
...
@@ -247,14 +239,8 @@ class PhiModel(nn.Module):
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
)
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
...
@@ -304,13 +290,10 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -304,13 +290,10 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/phi3_small.py
View file @
cdc1fa12
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
math
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
...
@@ -231,8 +231,6 @@ class Phi3SmallSelfAttention(nn.Module):
...
@@ -231,8 +231,6 @@ class Phi3SmallSelfAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
Optional
[
Tuple
[
torch
.
Tensor
]]]:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
...
@@ -248,7 +246,7 @@ class Phi3SmallSelfAttention(nn.Module):
...
@@ -248,7 +246,7 @@ class Phi3SmallSelfAttention(nn.Module):
v
=
v
.
reshape
(
-
1
,
self
.
head_dim
*
self
.
num_kv_heads_per_partion
)
v
=
v
.
reshape
(
-
1
,
self
.
head_dim
*
self
.
num_kv_heads_per_partion
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
=
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
dense
(
attn_output
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
return
output
...
@@ -282,8 +280,6 @@ class Phi3SmallDecoderLayer(nn.Module):
...
@@ -282,8 +280,6 @@ class Phi3SmallDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
@@ -291,8 +287,6 @@ class Phi3SmallDecoderLayer(nn.Module):
...
@@ -291,8 +287,6 @@ class Phi3SmallDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -338,8 +332,6 @@ class Phi3SmallModel(nn.Module):
...
@@ -338,8 +332,6 @@ class Phi3SmallModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
LongTensor
,
input_ids
:
torch
.
LongTensor
,
positions
:
Optional
[
torch
.
LongTensor
],
positions
:
Optional
[
torch
.
LongTensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
],
inputs_embeds
:
Optional
[
torch
.
Tensor
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -354,14 +346,8 @@ class Phi3SmallModel(nn.Module):
...
@@ -354,14 +346,8 @@ class Phi3SmallModel(nn.Module):
else
:
else
:
assert
intermediate_tensors
assert
intermediate_tensors
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
)
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
...
@@ -438,16 +424,12 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
...
@@ -438,16 +424,12 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
LongTensor
,
input_ids
:
torch
.
LongTensor
,
positions
:
Optional
[
torch
.
LongTensor
],
positions
:
Optional
[
torch
.
LongTensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
output_hidden_states
=
self
.
model
(
output_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
...
...
vllm/model_executor/models/phi3v.py
View file @
cdc1fa12
...
@@ -23,7 +23,6 @@ import torch.nn as nn
...
@@ -23,7 +23,6 @@ import torch.nn as nn
from
transformers
import
(
BatchFeature
,
CLIPVisionConfig
,
PretrainedConfig
,
from
transformers
import
(
BatchFeature
,
CLIPVisionConfig
,
PretrainedConfig
,
ProcessorMixin
)
ProcessorMixin
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
@@ -672,8 +671,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -672,8 +671,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
):
**
kwargs
:
object
):
...
@@ -691,8 +688,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -691,8 +688,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/phimoe.py
View file @
cdc1fa12
...
@@ -22,13 +22,13 @@
...
@@ -22,13 +22,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only PhiMoE model."""
"""Inference-only PhiMoE model."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -357,13 +357,11 @@ class PhiMoEAttention(nn.Module):
...
@@ -357,13 +357,11 @@ class PhiMoEAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -410,8 +408,6 @@ class PhiMoEDecoderLayer(nn.Module):
...
@@ -410,8 +408,6 @@ class PhiMoEDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
...
@@ -422,8 +418,6 @@ class PhiMoEDecoderLayer(nn.Module):
...
@@ -422,8 +418,6 @@ class PhiMoEDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
hidden_states
=
hidden_states
+
residual
hidden_states
=
hidden_states
+
residual
...
@@ -478,8 +472,6 @@ class PhiMoEModel(nn.Module):
...
@@ -478,8 +472,6 @@ class PhiMoEModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -494,13 +486,10 @@ class PhiMoEModel(nn.Module):
...
@@ -494,13 +486,10 @@ class PhiMoEModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
...
@@ -571,13 +560,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -571,13 +560,10 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/pixtral.py
View file @
cdc1fa12
...
@@ -16,7 +16,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
...
@@ -16,7 +16,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
from
transformers.models.pixtral.modeling_pixtral
import
(
from
transformers.models.pixtral.modeling_pixtral
import
(
PixtralRotaryEmbedding
,
apply_rotary_pos_emb
,
position_ids_in_meshgrid
)
PixtralRotaryEmbedding
,
apply_rotary_pos_emb
,
position_ids_in_meshgrid
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
...
@@ -270,8 +269,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -270,8 +269,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -291,8 +288,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -291,8 +288,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
cdc1fa12
...
@@ -15,13 +15,12 @@
...
@@ -15,13 +15,12 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
...
@@ -181,8 +180,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
...
@@ -181,8 +180,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
...
vllm/model_executor/models/qwen.py
View file @
cdc1fa12
...
@@ -6,13 +6,13 @@
...
@@ -6,13 +6,13 @@
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
"""Inference-only QWen model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -124,13 +124,11 @@ class QWenAttention(nn.Module):
...
@@ -124,13 +124,11 @@ class QWenAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
c_proj
(
attn_output
)
output
,
_
=
self
.
c_proj
(
attn_output
)
return
output
return
output
...
@@ -168,8 +166,6 @@ class QWenBlock(nn.Module):
...
@@ -168,8 +166,6 @@ class QWenBlock(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -181,8 +177,6 @@ class QWenBlock(nn.Module):
...
@@ -181,8 +177,6 @@ class QWenBlock(nn.Module):
hidden_states
=
self
.
attn
(
hidden_states
=
self
.
attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -225,8 +219,6 @@ class QWenModel(nn.Module):
...
@@ -225,8 +219,6 @@ class QWenModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -241,13 +233,10 @@ class QWenModel(nn.Module):
...
@@ -241,13 +233,10 @@ class QWenModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
h
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
h
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -373,12 +362,9 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
...
@@ -373,12 +362,9 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
vllm/model_executor/models/qwen2.py
View file @
cdc1fa12
...
@@ -23,13 +23,13 @@
...
@@ -23,13 +23,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Qwen2Config
from
transformers
import
Qwen2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
...
@@ -170,13 +170,11 @@ class Qwen2Attention(nn.Module):
...
@@ -170,13 +170,11 @@ class Qwen2Attention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -233,8 +231,6 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -233,8 +231,6 @@ class Qwen2DecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -247,8 +243,6 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -247,8 +243,6 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -328,8 +322,6 @@ class Qwen2Model(nn.Module):
...
@@ -328,8 +322,6 @@ class Qwen2Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -343,13 +335,10 @@ class Qwen2Model(nn.Module):
...
@@ -343,13 +335,10 @@ class Qwen2Model(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
...
@@ -468,13 +457,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -468,13 +457,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
@@ -553,12 +539,9 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -553,12 +539,9 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
return
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
)
intermediate_tensors
)
def
pooler
(
def
pooler
(
self
,
self
,
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
cdc1fa12
...
@@ -37,7 +37,6 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
...
@@ -37,7 +37,6 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from
transformers.models.qwen2_5_vl.configuration_qwen2_5_vl
import
(
from
transformers.models.qwen2_5_vl.configuration_qwen2_5_vl
import
(
Qwen2_5_VLConfig
,
Qwen2_5_VLVisionConfig
)
Qwen2_5_VLConfig
,
Qwen2_5_VLVisionConfig
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
...
@@ -992,8 +991,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -992,8 +991,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -1047,8 +1044,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1047,8 +1044,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
hidden_states
=
self
.
language_model
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
)
)
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
cdc1fa12
...
@@ -22,8 +22,8 @@
...
@@ -22,8 +22,8 @@
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Any
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
Any
,
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
TypedDict
,
Union
)
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -33,7 +33,6 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig,
...
@@ -33,7 +33,6 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig,
Qwen2AudioProcessor
)
Qwen2AudioProcessor
)
from
transformers.models.whisper
import
WhisperFeatureExtractor
from
transformers.models.whisper
import
WhisperFeatureExtractor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -380,8 +379,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -380,8 +379,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
@@ -400,8 +397,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -400,8 +397,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
cdc1fa12
...
@@ -23,14 +23,14 @@
...
@@ -23,14 +23,14 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
from
vllm.distributed
import
(
get_pp_group
,
...
@@ -232,13 +232,11 @@ class Qwen2MoeAttention(nn.Module):
...
@@ -232,13 +232,11 @@ class Qwen2MoeAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -296,8 +294,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
...
@@ -296,8 +294,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
...
@@ -310,8 +306,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
...
@@ -310,8 +306,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -358,8 +352,6 @@ class Qwen2MoeModel(nn.Module):
...
@@ -358,8 +352,6 @@ class Qwen2MoeModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
...
@@ -373,11 +365,8 @@ class Qwen2MoeModel(nn.Module):
...
@@ -373,11 +365,8 @@ class Qwen2MoeModel(nn.Module):
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"hidden_states"
:
hidden_states
,
...
@@ -416,13 +405,10 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -416,13 +405,10 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment