Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
643ecf7b
Unverified
Commit
643ecf7b
authored
Nov 16, 2024
by
Roger Wang
Committed by
GitHub
Nov 17, 2024
Browse files
[V1] Refactor model executable interface for all text-only language models (#10374)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
4fd93750
Changes
43
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
229 additions
and
41 deletions
+229
-41
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+14
-2
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+14
-2
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+14
-3
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+14
-2
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+14
-2
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+14
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+14
-2
vllm/model_executor/models/eagle.py
vllm/model_executor/models/eagle.py
+10
-3
vllm/model_executor/models/exaone.py
vllm/model_executor/models/exaone.py
+6
-1
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+14
-2
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+6
-1
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+10
-2
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+5
-2
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+13
-4
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+14
-2
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+14
-2
vllm/model_executor/models/granite.py
vllm/model_executor/models/granite.py
+6
-1
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+14
-2
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+7
-2
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+12
-2
No files found.
vllm/model_executor/models/arctic.py
View file @
643ecf7b
...
...
@@ -389,6 +389,9 @@ class ArcticModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -396,9 +399,13 @@ class ArcticModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
...
@@ -439,6 +446,9 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -446,9 +456,11 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/baichuan.py
View file @
643ecf7b
...
...
@@ -284,6 +284,9 @@ class BaiChuanModel(nn.Module):
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -291,9 +294,13 @@ class BaiChuanModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
...
...
@@ -363,6 +370,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -370,9 +380,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/bloom.py
View file @
643ecf7b
...
...
@@ -251,6 +251,9 @@ class BloomModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
word_embeddings_layernorm
(
self
.
word_embeddings
(
input_ids
))
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -258,10 +261,13 @@ class BloomModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
word_embeddings
(
input_ids
)
hidden_states
=
self
.
word_embeddings_layernorm
(
hidden_states
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
...
@@ -301,6 +307,9 @@ class BloomForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -308,9 +317,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/commandr.py
View file @
643ecf7b
...
...
@@ -280,6 +280,9 @@ class CohereModel(nn.Module):
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -287,9 +290,13 @@ class CohereModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
...
...
@@ -354,6 +361,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
@
torch
.
no_grad
()
def
forward
(
self
,
...
...
@@ -362,9 +372,11 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/dbrx.py
View file @
643ecf7b
...
...
@@ -321,6 +321,9 @@ class DbrxModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
d_model
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
wte
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -328,9 +331,13 @@ class DbrxModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
wte
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
assert
intermediate_tensors
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
...
@@ -376,6 +383,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -383,9 +393,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/deepseek.py
View file @
643ecf7b
...
...
@@ -353,6 +353,9 @@ class DeepseekModel(nn.Module):
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -360,9 +363,13 @@ class DeepseekModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
...
@@ -401,6 +408,9 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -408,9 +418,11 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
643ecf7b
...
...
@@ -445,6 +445,9 @@ class DeepseekV2Model(nn.Module):
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -452,9 +455,13 @@ class DeepseekV2Model(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
...
...
@@ -495,6 +502,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -502,9 +512,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/eagle.py
View file @
643ecf7b
...
...
@@ -78,6 +78,9 @@ class EAGLE(nn.Module):
def
sampler
(
self
):
return
self
.
model
.
sampler
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -86,11 +89,14 @@ class EAGLE(nn.Module):
attn_metadata
:
AttentionMetadata
,
previous_hidden_states
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
tok_embeds
=
self
.
model
.
model
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
fc
(
torch
.
cat
([
tok
_embeds
,
previous_hidden_states
],
dim
=-
1
))
torch
.
cat
([
inputs
_embeds
,
previous_hidden_states
],
dim
=-
1
))
inputs_embeds
[
positions
==
0
]
=
0
# masking inputs at position=0
...
...
@@ -100,7 +106,8 @@ class EAGLE(nn.Module):
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
)
intermediate_tensors
=
intermediate_tensors
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/exaone.py
View file @
643ecf7b
...
...
@@ -479,6 +479,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -486,9 +489,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
def
compute_logits
(
...
...
vllm/model_executor/models/falcon.py
View file @
643ecf7b
...
...
@@ -367,6 +367,9 @@ class FalconModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
word_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -374,9 +377,13 @@ class FalconModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
word_embeddings
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
...
...
@@ -432,6 +439,9 @@ class FalconForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
,
...
...
@@ -439,9 +449,11 @@ class FalconForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/gemma.py
View file @
643ecf7b
...
...
@@ -390,6 +390,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -397,9 +400,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/gemma2.py
View file @
643ecf7b
...
...
@@ -272,6 +272,9 @@ class Gemma2Model(nn.Module):
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
...
...
@@ -285,7 +288,7 @@ class Gemma2Model(nn.Module):
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
embed_token
s
(
input_ids
)
hidden_states
=
self
.
get_input_embedding
s
(
input_ids
)
hidden_states
*=
self
.
normalizer
residual
=
None
else
:
...
...
@@ -414,6 +417,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -421,9 +427,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/gpt2.py
View file @
643ecf7b
...
...
@@ -209,6 +209,9 @@ class GPT2Model(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
n_embd
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
wte
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -220,7 +223,7 @@ class GPT2Model(nn.Module):
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
wte
(
input_ids
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
else
:
...
...
@@ -262,7 +265,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
wte
(
input_ids
)
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
643ecf7b
...
...
@@ -218,6 +218,9 @@ class GPTBigCodeModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
n_embd
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
wte
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -225,11 +228,12 @@ class GPTBigCodeModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position
_embeds
=
self
.
wpe
(
position
_ids
)
hidden_states
=
inputs_embeds
+
position_
embe
ds
if
inputs_embeds
is
None
:
inputs
_embeds
=
self
.
get_input_embeddings
(
input
_ids
)
hidden_states
=
inputs_embeds
+
self
.
wpe
(
position_
i
ds
)
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
...
...
@@ -285,6 +289,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -292,9 +299,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/gpt_j.py
View file @
643ecf7b
...
...
@@ -201,6 +201,9 @@ class GPTJModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
n_embd
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
wte
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -208,9 +211,13 @@ class GPTJModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
wte
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
...
...
@@ -250,6 +257,9 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -257,9 +267,11 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/gpt_neox.py
View file @
643ecf7b
...
...
@@ -214,6 +214,9 @@ class GPTNeoXModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_in
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -221,9 +224,13 @@ class GPTNeoXModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_in
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
...
...
@@ -262,6 +269,9 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
gpt_neox
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
gpt_neox
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -269,9 +279,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
gpt_neox
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/granite.py
View file @
643ecf7b
...
...
@@ -409,6 +409,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else
:
self
.
lm_head
=
PPMissingLayer
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -416,9 +419,11 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
def
compute_logits
(
...
...
vllm/model_executor/models/granitemoe.py
View file @
643ecf7b
...
...
@@ -277,6 +277,9 @@ class GraniteMoeModel(nn.Module):
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -284,9 +287,13 @@ class GraniteMoeModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
hidden_states
*=
self
.
embedding_multiplier
residual
=
None
else
:
...
...
@@ -366,6 +373,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
sampler
=
get_sampler
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -373,9 +383,11 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/internlm2.py
View file @
643ecf7b
...
...
@@ -290,7 +290,7 @@ class InternLM2Model(nn.Module):
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
tok
_embeddings
(
input_ids
)
hidden_states
=
self
.
get_input
_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
...
...
@@ -335,6 +335,9 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -342,9 +345,11 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
vllm/model_executor/models/jais.py
View file @
643ecf7b
...
...
@@ -250,6 +250,9 @@ class JAISModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
n_embd
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
wte
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -257,9 +260,11 @@ class JAISModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
if
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
wte
(
input_ids
)
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
if
self
.
wpe
is
not
None
:
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
...
...
@@ -311,6 +316,9 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
transformer
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
@@ -318,9 +326,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment