Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
643ecf7b
Unverified
Commit
643ecf7b
authored
Nov 16, 2024
by
Roger Wang
Committed by
GitHub
Nov 17, 2024
Browse files
[V1] Refactor model executable interface for all text-only language models (#10374)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
4fd93750
Changes
43
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
212 additions
and
43 deletions
+212
-43
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+14
-2
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+13
-2
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+6
-1
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+14
-2
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+14
-2
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+14
-2
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+6
-1
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+13
-6
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+14
-2
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+14
-2
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+7
-1
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+14
-2
vllm/model_executor/models/phi3_small.py
vllm/model_executor/models/phi3_small.py
+11
-8
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+14
-2
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+14
-2
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+1
-1
vllm/model_executor/models/qwen2_cls.py
vllm/model_executor/models/qwen2_cls.py
+6
-1
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+14
-2
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+6
-1
vllm/model_executor/models/solar.py
vllm/model_executor/models/solar.py
+3
-1
No files found.
vllm/model_executor/models/jamba.py
View file @
643ecf7b
...
@@ -292,6 +292,9 @@ class JambaModel(nn.Module):
...
@@ -292,6 +292,9 @@ class JambaModel(nn.Module):
self
.
final_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
final_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -299,8 +302,12 @@ class JambaModel(nn.Module):
...
@@ -299,8 +302,12 @@ class JambaModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
...
@@ -381,12 +388,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
...
@@ -381,12 +388,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
config
.
vocab_size
)
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
if
self
.
mamba_cache
is
None
:
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
_get_graph_batch_size
(
max_batch_size
=
(
_get_graph_batch_size
(
...
@@ -409,7 +420,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
...
@@ -409,7 +420,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
mamba_cache_tensors
[
1
],
mamba_cache_tensors
[
1
],
state_indices_tensor
)
state_indices_tensor
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
mamba_cache_params
)
attn_metadata
,
mamba_cache_params
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
def
copy_inputs_before_cuda_graphs
(
self
,
input_buffers
,
**
kwargs
):
...
...
vllm/model_executor/models/mamba.py
View file @
643ecf7b
...
@@ -106,15 +106,22 @@ class MambaModel(nn.Module):
...
@@ -106,15 +106,22 @@ class MambaModel(nn.Module):
self
.
norm_f
=
RMSNorm
(
config
.
hidden_size
,
self
.
norm_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
eps
=
config
.
layer_norm_epsilon
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embeddings
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
...
@@ -168,12 +175,16 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -168,12 +175,16 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
config
.
vocab_size
)
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
backbone
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
KVCache
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
):
**
kwargs
):
if
self
.
mamba_cache
is
None
:
if
self
.
mamba_cache
is
None
:
max_batch_size
=
(
_get_graph_batch_size
(
max_batch_size
=
(
_get_graph_batch_size
(
...
@@ -194,7 +205,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -194,7 +205,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
state_indices_tensor
)
state_indices_tensor
)
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
hidden_states
=
self
.
backbone
(
input_ids
,
positions
,
attn_metadata
,
mamba_cache_params
)
mamba_cache_params
,
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/minicpm.py
View file @
643ecf7b
...
@@ -504,6 +504,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -504,6 +504,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
model
=
MiniCPMModel
(
vllm_config
=
vllm_config
,
self
.
model
=
MiniCPMModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -511,9 +514,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -511,9 +514,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/mixtral.py
View file @
643ecf7b
...
@@ -281,6 +281,9 @@ class MixtralModel(nn.Module):
...
@@ -281,6 +281,9 @@ class MixtralModel(nn.Module):
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -288,9 +291,13 @@ class MixtralModel(nn.Module):
...
@@ -288,9 +291,13 @@ class MixtralModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
residual
=
None
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
...
@@ -363,6 +370,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -363,6 +370,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -370,9 +380,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -370,9 +380,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
643ecf7b
...
@@ -318,6 +318,9 @@ class MixtralModel(nn.Module):
...
@@ -318,6 +318,9 @@ class MixtralModel(nn.Module):
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -325,9 +328,13 @@ class MixtralModel(nn.Module):
...
@@ -325,9 +328,13 @@ class MixtralModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
residual
=
None
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
...
@@ -368,6 +375,9 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
...
@@ -368,6 +375,9 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -375,9 +385,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
...
@@ -375,9 +385,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/mpt.py
View file @
643ecf7b
...
@@ -237,6 +237,9 @@ class MPTModel(nn.Module):
...
@@ -237,6 +237,9 @@ class MPTModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
d_model
))
config
.
d_model
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
wte
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -244,9 +247,13 @@ class MPTModel(nn.Module):
...
@@ -244,9 +247,13 @@ class MPTModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
wte
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
@@ -283,6 +290,9 @@ class MPTForCausalLM(nn.Module, SupportsPP):
...
@@ -283,6 +290,9 @@ class MPTForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -290,9 +300,11 @@ class MPTForCausalLM(nn.Module, SupportsPP):
...
@@ -290,9 +300,11 @@ class MPTForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/nemotron.py
View file @
643ecf7b
...
@@ -440,6 +440,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -440,6 +440,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -447,9 +450,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -447,9 +450,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
return
model_output
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/olmo.py
View file @
643ecf7b
...
@@ -248,6 +248,9 @@ class OlmoModel(nn.Module):
...
@@ -248,6 +248,9 @@ class OlmoModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -255,17 +258,16 @@ class OlmoModel(nn.Module):
...
@@ -255,17 +258,16 @@ class OlmoModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
"""
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
"""
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
# Get embeddings of input.
if
inputs_embeds
is
not
None
:
# shape: (batch_size, seq_len, d_model)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
# embed positions
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
@@ -315,6 +317,9 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
...
@@ -315,6 +317,9 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -322,6 +327,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
...
@@ -322,6 +327,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
...
@@ -329,6 +335,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
...
@@ -329,6 +335,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/olmoe.py
View file @
643ecf7b
...
@@ -269,6 +269,9 @@ class OlmoeModel(nn.Module):
...
@@ -269,6 +269,9 @@ class OlmoeModel(nn.Module):
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -276,9 +279,13 @@ class OlmoeModel(nn.Module):
...
@@ -276,9 +279,13 @@ class OlmoeModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
residual
=
None
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
...
@@ -326,6 +333,9 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
...
@@ -326,6 +333,9 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -333,9 +343,11 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
...
@@ -333,9 +343,11 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/orion.py
View file @
643ecf7b
...
@@ -237,6 +237,9 @@ class OrionModel(nn.Module):
...
@@ -237,6 +237,9 @@ class OrionModel(nn.Module):
"hidden_states"
,
"hidden_states"
,
],
config
.
hidden_size
))
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -244,9 +247,13 @@ class OrionModel(nn.Module):
...
@@ -244,9 +247,13 @@ class OrionModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
@@ -286,6 +293,9 @@ class OrionForCausalLM(nn.Module, SupportsPP):
...
@@ -286,6 +293,9 @@ class OrionForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -293,9 +303,11 @@ class OrionForCausalLM(nn.Module, SupportsPP):
...
@@ -293,9 +303,11 @@ class OrionForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/persimmon.py
View file @
643ecf7b
...
@@ -235,6 +235,9 @@ class PersimmonModel(nn.Module):
...
@@ -235,6 +235,9 @@ class PersimmonModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -248,7 +251,7 @@ class PersimmonModel(nn.Module):
...
@@ -248,7 +251,7 @@ class PersimmonModel(nn.Module):
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
else
:
else
:
hidden_states
=
self
.
embed_token
s
(
input_ids
)
hidden_states
=
self
.
get_input_embedding
s
(
input_ids
)
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
@@ -282,6 +285,9 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
...
@@ -282,6 +285,9 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/phi.py
View file @
643ecf7b
...
@@ -218,6 +218,9 @@ class PhiModel(nn.Module):
...
@@ -218,6 +218,9 @@ class PhiModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -225,9 +228,13 @@ class PhiModel(nn.Module):
...
@@ -225,9 +228,13 @@ class PhiModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
@@ -303,6 +310,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -303,6 +310,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -310,9 +320,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -310,9 +320,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/phi3_small.py
View file @
643ecf7b
...
@@ -324,11 +324,8 @@ class Phi3SmallModel(nn.Module):
...
@@ -324,11 +324,8 @@ class Phi3SmallModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
config
.
hidden_size
))
def
get_input_embeddings
(
self
):
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
return
self
.
embed_tokens
(
input_ids
)
def
set_input_embeddings
(
self
,
value
):
self
.
embed_tokens
=
value
def
forward
(
def
forward
(
self
,
self
,
...
@@ -337,9 +334,13 @@ class Phi3SmallModel(nn.Module):
...
@@ -337,9 +334,13 @@ class Phi3SmallModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
if
(
self
.
mup_embedding_multiplier
is
not
None
if
(
self
.
mup_embedding_multiplier
is
not
None
and
self
.
mup_embedding_multiplier
>
0.0
):
and
self
.
mup_embedding_multiplier
>
0.0
):
hidden_states
=
hidden_states
*
self
.
mup_embedding_multiplier
hidden_states
=
hidden_states
*
self
.
mup_embedding_multiplier
...
@@ -397,8 +398,8 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
...
@@ -397,8 +398,8 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
else
:
else
:
self
.
dummy_token_indices
=
None
self
.
dummy_token_indices
=
None
def
get_input_embeddings
(
self
)
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_tokens
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
set_input_embeddings
(
self
,
value
):
def
set_input_embeddings
(
self
,
value
):
self
.
model
.
embed_tokens
=
value
self
.
model
.
embed_tokens
=
value
...
@@ -433,6 +434,7 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
...
@@ -433,6 +434,7 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
output_hidden_states
=
self
.
model
(
output_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
...
@@ -440,6 +442,7 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
...
@@ -440,6 +442,7 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
)
output_hidden_states
=
output_hidden_states
output_hidden_states
=
output_hidden_states
return
output_hidden_states
return
output_hidden_states
...
...
vllm/model_executor/models/phimoe.py
View file @
643ecf7b
...
@@ -465,6 +465,9 @@ class PhiMoEModel(nn.Module):
...
@@ -465,6 +465,9 @@ class PhiMoEModel(nn.Module):
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -472,9 +475,13 @@ class PhiMoEModel(nn.Module):
...
@@ -472,9 +475,13 @@ class PhiMoEModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
residual
=
None
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
...
@@ -560,6 +567,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -560,6 +567,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -567,9 +577,11 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -567,9 +577,11 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/qwen.py
View file @
643ecf7b
...
@@ -578,6 +578,9 @@ class QWenModel(nn.Module):
...
@@ -578,6 +578,9 @@ class QWenModel(nn.Module):
quant_config
=
quant_config
)
if
hasattr
(
quant_config
=
quant_config
)
if
hasattr
(
config
,
"visual"
)
else
None
config
,
"visual"
)
else
None
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
wte
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -586,6 +589,7 @@ class QWenModel(nn.Module):
...
@@ -586,6 +589,7 @@ class QWenModel(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
pixel_values
:
Optional
[
QwenImageInputs
],
pixel_values
:
Optional
[
QwenImageInputs
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
img_pos
=
None
img_pos
=
None
# If pixel / visual embeddings are provided, this is a visual model
# If pixel / visual embeddings are provided, this is a visual model
...
@@ -606,6 +610,10 @@ class QWenModel(nn.Module):
...
@@ -606,6 +610,10 @@ class QWenModel(nn.Module):
)
)
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
hidden_states
=
self
.
wte
(
input_ids
)
hidden_states
=
self
.
wte
(
input_ids
)
# Merge the image embeddings into the hidden states if actually have
# Merge the image embeddings into the hidden states if actually have
# visual features and the corresponding image tokens
# visual features and the corresponding image tokens
...
@@ -915,6 +923,9 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
...
@@ -915,6 +923,9 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
)
)
return
None
return
None
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -922,7 +933,8 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
...
@@ -922,7 +933,8 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
intermediate_tensors
is
not
None
:
if
intermediate_tensors
is
not
None
:
input_ids
=
None
input_ids
=
None
...
@@ -932,7 +944,7 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
...
@@ -932,7 +944,7 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
attn_metadata
,
intermediate_tensors
,
pixel_values
)
pixel_values
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/qwen2.py
View file @
643ecf7b
...
@@ -309,7 +309,7 @@ class Qwen2Model(nn.Module):
...
@@ -309,7 +309,7 @@ class Qwen2Model(nn.Module):
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
else
:
else
:
hidden_states
=
self
.
embed_token
s
(
input_ids
)
hidden_states
=
self
.
get_input_embedding
s
(
input_ids
)
residual
=
None
residual
=
None
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
...
...
vllm/model_executor/models/qwen2_cls.py
View file @
643ecf7b
...
@@ -72,6 +72,9 @@ class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -72,6 +72,9 @@ class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
normalize
=
False
,
normalize
=
False
,
softmax
=
True
)
softmax
=
True
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -79,9 +82,11 @@ class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -79,9 +82,11 @@ class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
logits
,
_
=
self
.
score
(
hidden_states
)
logits
,
_
=
self
.
score
(
hidden_states
)
return
logits
return
logits
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
643ecf7b
...
@@ -344,6 +344,9 @@ class Qwen2MoeModel(nn.Module):
...
@@ -344,6 +344,9 @@ class Qwen2MoeModel(nn.Module):
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -351,9 +354,13 @@ class Qwen2MoeModel(nn.Module):
...
@@ -351,9 +354,13 @@ class Qwen2MoeModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
residual
=
None
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
...
@@ -395,6 +402,9 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -395,6 +402,9 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -402,9 +412,11 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
...
@@ -402,9 +412,11 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
643ecf7b
...
@@ -85,6 +85,9 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -85,6 +85,9 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -92,9 +95,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -92,9 +95,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
logits
,
_
=
self
.
score
(
hidden_states
)
logits
,
_
=
self
.
score
(
hidden_states
)
return
logits
return
logits
...
...
vllm/model_executor/models/solar.py
View file @
643ecf7b
...
@@ -456,9 +456,11 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -456,9 +456,11 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
return
model_output
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment