Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
925f3332
Unverified
Commit
925f3332
authored
Mar 24, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 25, 2024
Browse files
[Core] Refactor Attention Take 2 (#3462)
parent
b0dfa91d
Changes
44
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
276 additions
and
360 deletions
+276
-360
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+14
-17
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+13
-17
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+14
-19
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+14
-19
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+14
-18
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+14
-18
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+13
-17
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+15
-20
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+13
-17
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+14
-18
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+14
-18
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+14
-18
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+13
-17
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+17
-22
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+13
-17
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+14
-18
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+13
-18
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+13
-17
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+13
-17
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+14
-18
No files found.
vllm/model_executor/models/falcon.py
View file @
925f3332
...
...
@@ -19,16 +19,15 @@
"""PyTorch Falcon model."""
import
math
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Union
import
torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
transformers
import
FalconConfig
as
HF_FalconConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -48,7 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
RWConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
...
@@ -177,8 +175,8 @@ class FalconAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
bias
=
self
.
query_key_value
(
hidden_states
)
if
bias
is
not
None
:
...
...
@@ -186,8 +184,7 @@ class FalconAttention(nn.Module):
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
use_rotary
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
bias
=
self
.
dense
(
attn_output
)
return
attn_output
,
bias
...
...
@@ -263,8 +260,8 @@ class FalconDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
...
...
@@ -279,7 +276,7 @@ class FalconDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
attention_layernorm_out
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
if
self
.
reduce_row_parallel_results
and
attention_bias
is
not
None
:
attention_output
+=
attention_bias
...
...
@@ -343,8 +340,8 @@ class FalconModel(nn.Module):
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
word_embeddings
(
input_ids
)
for
i
in
range
(
len
(
self
.
h
)):
...
...
@@ -353,7 +350,7 @@ class FalconModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -378,14 +375,14 @@ class FalconForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
,
attn
_metadata
,
)
return
hidden_states
...
...
vllm/model_executor/models/gemma.py
View file @
925f3332
...
...
@@ -20,10 +20,9 @@ import torch
from
torch
import
nn
from
transformers
import
GemmaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
...
...
@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GemmaMLP
(
nn
.
Module
):
...
...
@@ -133,14 +130,13 @@ class GemmaAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -177,8 +173,8 @@ class GemmaDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -192,7 +188,7 @@ class GemmaDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -226,8 +222,8 @@ class GemmaModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
# Normalize the embedding by sqrt(hidden_size)
...
...
@@ -240,7 +236,7 @@ class GemmaModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
...
@@ -290,11 +286,11 @@ class GemmaForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/gpt2.py
View file @
925f3332
...
...
@@ -17,15 +17,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
GPT2Config
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPT2Attention
(
nn
.
Module
):
...
...
@@ -79,14 +76,12 @@ class GPT2Attention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
key_cache
,
value_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
return
attn_output
...
...
@@ -144,15 +139,15 @@ class GPT2Block(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# residual connection
hidden_states
=
attn_output
+
residual
...
...
@@ -190,8 +185,8 @@ class GPT2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
...
...
@@ -199,7 +194,7 @@ class GPT2Model(nn.Module):
for
i
in
range
(
len
(
self
.
h
)):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn
_metadata
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -224,11 +219,11 @@ class GPT2LMHeadModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
925f3332
...
...
@@ -18,15 +18,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
GPTBigCodeConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPTBigCodeAttention
(
nn
.
Module
):
...
...
@@ -94,8 +91,8 @@ class GPTBigCodeAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
(
...
...
@@ -105,9 +102,7 @@ class GPTBigCodeAttention(nn.Module):
],
dim
=-
1
,
)
key_cache
,
value_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
return
attn_output
...
...
@@ -165,15 +160,15 @@ class GPTBigCodeBlock(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# residual connection
hidden_states
=
attn_output
+
residual
...
...
@@ -211,8 +206,8 @@ class GPTBigCodeModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
...
...
@@ -220,7 +215,7 @@ class GPTBigCodeModel(nn.Module):
for
i
in
range
(
len
(
self
.
h
)):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn
_metadata
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -245,11 +240,11 @@ class GPTBigCodeForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/gpt_j.py
View file @
925f3332
...
...
@@ -16,15 +16,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
GPTJConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPTJAttention
(
nn
.
Module
):
...
...
@@ -93,14 +90,13 @@ class GPTJAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
_
=
self
.
out_proj
(
attn_output
)
return
attn_output
...
...
@@ -154,8 +150,8 @@ class GPTJBlock(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
...
...
@@ -163,7 +159,7 @@ class GPTJBlock(nn.Module):
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
mlp_output
=
self
.
mlp
(
hidden_states
)
hidden_states
=
attn_output
+
mlp_output
+
residual
...
...
@@ -192,8 +188,8 @@ class GPTJModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
for
i
in
range
(
len
(
self
.
h
)):
...
...
@@ -202,7 +198,7 @@ class GPTJModel(nn.Module):
position_ids
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -232,11 +228,11 @@ class GPTJForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/gpt_neox.py
View file @
925f3332
...
...
@@ -16,15 +16,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
GPTNeoXConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -41,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPTNeoXAttention
(
nn
.
Module
):
...
...
@@ -94,14 +91,13 @@ class GPTNeoXAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
...
...
@@ -155,15 +151,15 @@ class GPTNeoXLayer(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
attn_input
=
self
.
input_layernorm
(
hidden_states
)
attn_output
=
self
.
attention
(
position_ids
=
position_ids
,
hidden_states
=
attn_input
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
if
self
.
use_parallel_residual
:
...
...
@@ -208,8 +204,8 @@ class GPTNeoXModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_in
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
...
...
@@ -218,7 +214,7 @@ class GPTNeoXModel(nn.Module):
position_ids
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
return
hidden_states
...
...
@@ -246,11 +242,11 @@ class GPTNeoXForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
gpt_neox
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/internlm2.py
View file @
925f3332
...
...
@@ -5,9 +5,8 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
...
...
@@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
InternLM2MLP
(
nn
.
Module
):
...
...
@@ -124,14 +121,13 @@ class InternLM2Attention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
wqkv
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
wo
(
attn_output
)
return
output
...
...
@@ -172,8 +168,8 @@ class InternLMDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -187,7 +183,7 @@ class InternLMDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -221,8 +217,8 @@ class InternLM2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
tok_embeddings
(
input_ids
)
residual
=
None
...
...
@@ -232,7 +228,7 @@ class InternLM2Model(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
...
@@ -258,11 +254,11 @@ class InternLM2ForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/jais.py
View file @
925f3332
...
...
@@ -20,14 +20,13 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
vllm.transformers_utils.configs
import
JAISConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
...
...
@@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import (
from
vllm.sequence
import
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
SwiGLUActivation
(
nn
.
Module
):
...
...
@@ -122,14 +119,12 @@ class JAISAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
key_cache
,
value_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
return
attn_output
...
...
@@ -196,15 +191,15 @@ class JAISBlock(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# residual connection
hidden_states
=
attn_output
+
residual
...
...
@@ -248,8 +243,8 @@ class JAISModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
if
self
.
wpe
is
not
None
:
...
...
@@ -262,7 +257,7 @@ class JAISModel(nn.Module):
for
i
in
range
(
len
(
self
.
h
)):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn
_metadata
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -293,11 +288,11 @@ class JAISLMHeadModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -348,4 +343,4 @@ class JAISLMHeadModel(nn.Module):
loaded_weight
=
loaded_weight
.
t
()
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
\ No newline at end of file
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/llama.py
View file @
925f3332
...
...
@@ -27,10 +27,9 @@ import torch
from
torch
import
nn
from
transformers
import
LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
...
...
@@ -48,8 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
LlamaMLP
(
nn
.
Module
):
...
...
@@ -150,14 +147,13 @@ class LlamaAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -203,8 +199,8 @@ class LlamaDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -218,7 +214,7 @@ class LlamaDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -258,8 +254,8 @@ class LlamaModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
...
...
@@ -269,7 +265,7 @@ class LlamaModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
...
@@ -336,11 +332,11 @@ class LlamaForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/mixtral.py
View file @
925f3332
...
...
@@ -21,15 +21,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
...
...
@@ -51,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
MixtralMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
...
...
@@ -209,14 +206,13 @@ class MixtralAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -254,8 +250,8 @@ class MixtralDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -269,7 +265,7 @@ class MixtralDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -309,15 +305,15 @@ class MixtralModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
kv_caches
[
i
],
attn
_metadata
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -377,11 +373,11 @@ class MixtralForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
925f3332
...
...
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
numpy
as
np
...
...
@@ -31,8 +31,7 @@ import torch.nn.functional as F
from
torch
import
nn
from
transformers
import
MixtralConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
ReplicatedLinear
,
...
...
@@ -52,8 +51,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
MixtralMLP
(
nn
.
Module
):
...
...
@@ -227,14 +224,13 @@ class MixtralAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -269,8 +265,8 @@ class MixtralDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -284,7 +280,7 @@ class MixtralDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -319,15 +315,15 @@ class MixtralModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
kv_caches
[
i
],
attn
_metadata
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -352,11 +348,11 @@ class MixtralForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/mpt.py
View file @
925f3332
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import
math
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
import
torch.nn
as
nn
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -25,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
def
_get_alibi_slopes
(
total_num_heads
:
int
,
...
...
@@ -116,8 +113,8 @@ class MPTAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
del
position_ids
# unused.
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
...
...
@@ -127,8 +124,7 @@ class MPTAttention(nn.Module):
if
self
.
qk_ln
:
q
=
self
.
q_ln
(
q
)
k
=
self
.
k_ln
(
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
...
...
@@ -184,15 +180,15 @@ class MPTBlock(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
x
=
self
.
norm_1
(
hidden_states
)
x
=
self
.
attn
(
position_ids
=
position_ids
,
hidden_states
=
x
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
hidden_states
=
hidden_states
+
x
x
=
self
.
norm_2
(
hidden_states
)
...
...
@@ -230,8 +226,8 @@ class MPTModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
for
i
in
range
(
len
(
self
.
blocks
)):
...
...
@@ -240,7 +236,7 @@ class MPTModel(nn.Module):
position_ids
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
norm_f
(
hidden_states
)
return
hidden_states
...
...
@@ -267,11 +263,11 @@ class MPTForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/olmo.py
View file @
925f3332
...
...
@@ -42,8 +42,7 @@ import torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
...
...
@@ -67,8 +66,6 @@ from vllm.sequence import SamplerOutput
# this model must need this dependency
from
hf_olmo
import
OLMoConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
SwiGLU
(
nn
.
Module
):
...
...
@@ -146,16 +143,15 @@ class OlmoAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
attn_norm
(
hidden_states
)
qkv
,
_
=
self
.
att_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
config
.
rope
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
attn_out
(
attn_output
)
return
output
...
...
@@ -241,12 +237,12 @@ class OlmoBlock(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Attention block.
og_x
=
hidden_states
x
=
self
.
attn
(
positions
,
hidden_states
,
kv_cache
,
input
_metadata
)
x
=
self
.
attn
(
positions
,
hidden_states
,
kv_cache
,
attn
_metadata
)
x
=
x
+
og_x
# MLP block.
...
...
@@ -296,8 +292,8 @@ class OlmoModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
...
...
@@ -313,7 +309,7 @@ class OlmoModel(nn.Module):
positions
,
x
,
kv_caches
[
block_idx
],
input
_metadata
,
attn
_metadata
,
)
# Apply final layer norm.
...
...
@@ -344,14 +340,14 @@ class OLMoForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
return
hidden_states
...
...
vllm/model_executor/models/opt.py
View file @
925f3332
...
...
@@ -17,15 +17,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
OPTConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -42,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
OPTLearnedPositionalEmbedding
(
nn
.
Embedding
):
...
...
@@ -97,14 +94,12 @@ class OPTAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
key_cache
,
value_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
...
...
@@ -152,8 +147,8 @@ class OPTDecoderLayer(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
# Self Attention
residual
=
hidden_states
...
...
@@ -162,7 +157,7 @@ class OPTDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
)
attn
_metadata
=
attn
_metadata
)
hidden_states
=
residual
+
hidden_states
# 350m applies layer norm AFTER attention
if
not
self
.
do_layer_norm_before
:
...
...
@@ -241,8 +236,8 @@ class OPTDecoder(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
pos_embeds
=
self
.
embed_positions
(
positions
)
...
...
@@ -252,7 +247,7 @@ class OPTDecoder(nn.Module):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn
_metadata
)
if
self
.
final_layer_norm
is
not
None
:
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
...
...
@@ -275,10 +270,10 @@ class OPTModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
attn
_metadata
)
class
OPTForCausalLM
(
nn
.
Module
):
...
...
@@ -300,11 +295,11 @@ class OPTForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/orion.py
View file @
925f3332
...
...
@@ -10,9 +10,8 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -29,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
OrionMLP
(
nn
.
Module
):
...
...
@@ -128,14 +125,13 @@ class OrionAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -178,8 +174,8 @@ class OrionDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -189,7 +185,7 @@ class OrionDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
hidden_states
=
residual
+
hidden_states
...
...
@@ -227,8 +223,8 @@ class OrionModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
...
...
@@ -238,7 +234,7 @@ class OrionModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
...
...
@@ -264,11 +260,11 @@ class OrionForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/phi.py
View file @
925f3332
...
...
@@ -35,15 +35,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -60,8 +59,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
PhiAttention
(
nn
.
Module
):
...
...
@@ -115,14 +112,13 @@ class PhiAttention(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
...
...
@@ -172,8 +168,8 @@ class PhiLayer(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
...
@@ -181,7 +177,7 @@ class PhiLayer(nn.Module):
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
feed_forward_hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
attn_outputs
+
feed_forward_hidden_states
+
residual
...
...
@@ -209,8 +205,8 @@ class PhiModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
self
.
config
.
num_hidden_layers
):
...
...
@@ -219,7 +215,7 @@ class PhiModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
...
...
@@ -248,11 +244,11 @@ class PhiForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
...
...
vllm/model_executor/models/qwen.py
View file @
925f3332
...
...
@@ -10,9 +10,8 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
...
...
@@ -30,8 +29,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
QWenMLP
(
nn
.
Module
):
...
...
@@ -111,15 +108,13 @@ class QWenAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
c_proj
(
attn_output
)
return
output
...
...
@@ -153,8 +148,8 @@ class QWenBlock(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -167,7 +162,7 @@ class QWenBlock(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -201,8 +196,8 @@ class QWenModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
residual
=
None
...
...
@@ -212,7 +207,7 @@ class QWenModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
ln_f
(
hidden_states
,
residual
)
...
...
@@ -238,11 +233,11 @@ class QWenLMHeadModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/qwen2.py
View file @
925f3332
...
...
@@ -28,9 +28,8 @@ import torch
from
torch
import
nn
from
transformers
import
Qwen2Config
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
...
...
@@ -49,8 +48,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from
vllm.sequence
import
SamplerOutput
from
vllm.config
import
LoRAConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
Qwen2MLP
(
nn
.
Module
):
...
...
@@ -147,14 +144,13 @@ class Qwen2Attention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -197,8 +193,8 @@ class Qwen2DecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -212,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -248,8 +244,8 @@ class Qwen2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
...
...
@@ -259,7 +255,7 @@ class Qwen2Model(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
...
@@ -315,11 +311,11 @@ class Qwen2ForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/stablelm.py
View file @
925f3332
...
...
@@ -25,9 +25,8 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -44,8 +43,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
StablelmMLP
(
nn
.
Module
):
...
...
@@ -134,14 +131,13 @@ class StablelmAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -166,8 +162,8 @@ class StablelmDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
residual
=
hidden_states
...
...
@@ -176,7 +172,7 @@ class StablelmDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
hidden_states
=
residual
+
hidden_states
...
...
@@ -211,8 +207,8 @@ class StableLMEpochModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
...
...
@@ -221,7 +217,7 @@ class StableLMEpochModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
...
...
@@ -246,11 +242,11 @@ class StablelmForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/starcoder2.py
View file @
925f3332
...
...
@@ -18,15 +18,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
Starcoder2Config
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -43,8 +42,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
Starcoder2Attention
(
nn
.
Module
):
...
...
@@ -111,14 +108,13 @@ class Starcoder2Attention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -171,8 +167,8 @@ class Starcoder2DecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
# Self Attention
residual
=
hidden_states
...
...
@@ -181,7 +177,7 @@ class Starcoder2DecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
hidden_states
=
residual
+
hidden_states
...
...
@@ -217,14 +213,14 @@ class Starcoder2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
attn
_metadata
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
...
...
@@ -258,11 +254,11 @@ class Starcoder2ForCausalLM(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment