Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c4f76e3
Commit
7c4f76e3
authored
Apr 15, 2024
by
zhuwenwen
Browse files
merge v0.4.0
parents
2da0dd3e
51c31bc1
Changes
332
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1920 additions
and
522 deletions
+1920
-522
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+71
-42
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+30
-27
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+46
-29
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+329
-0
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+421
-0
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+39
-35
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+44
-40
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+80
-27
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+26
-26
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+29
-27
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+29
-25
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+27
-24
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+31
-27
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+340
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+42
-29
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+244
-0
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+31
-28
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+31
-30
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+30
-27
vllm/model_executor/models/neuron/llama.py
vllm/model_executor/models/neuron/llama.py
+0
-79
No files found.
vllm/model_executor/models/baichuan.py
View file @
7c4f76e3
...
...
@@ -25,18 +25,19 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -44,8 +45,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
total_num_heads
))
...
...
@@ -151,10 +150,10 @@ class BaiChuanAttention(nn.Module):
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
else
:
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -163,22 +162,20 @@ class BaiChuanAttention(nn.Module):
base
=
self
.
rope_theta
,
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
W_pack
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
postion_embedding
!=
"ALIBI"
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -217,8 +214,8 @@ class BaiChuanDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -232,7 +229,7 @@ class BaiChuanDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -267,8 +264,8 @@ class BaiChuanModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
...
...
@@ -278,7 +275,7 @@ class BaiChuanModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
...
@@ -286,36 +283,61 @@ class BaiChuanModel(nn.Module):
class
BaiChuanBaseForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"W_pack"
:
[
"W_pack"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"W_pack"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
config
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
def
__init__
(
self
,
config
,
position_embedding
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
@@ -334,7 +356,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
if
"rotary_emb.inv_freq"
in
name
:
continue
if
name
==
"lm_head.weight"
:
# Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to:
# Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
# Distinguish between Baichuan and Baichuan2 by checking the
# vocab size. This is suggested by
...
...
@@ -368,19 +391,25 @@ class BaiChuanBaseForCausalLM(nn.Module):
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
"""Baichuan 13B and Baichuan2 7B/13B."""
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
if
config
.
hidden_size
==
4096
:
# baichuan2 7b
super
().
__init__
(
config
,
"ROPE"
,
linear_method
)
super
().
__init__
(
config
,
"ROPE"
,
linear_method
,
lora_config
)
else
:
# baichuan 13b, baichuan2 13b
super
().
__init__
(
config
,
"ALIBI"
,
linear_method
)
super
().
__init__
(
config
,
"ALIBI"
,
linear_method
,
lora_config
)
class
BaiChuanForCausalLM
(
BaiChuanBaseForCausalLM
):
"""Baichuan 7B."""
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
):
super
().
__init__
(
config
,
"ROPE"
,
linear_method
)
def
__init__
(
self
,
config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
(
config
,
"ROPE"
,
linear_method
,
lora_config
)
vllm/model_executor/models/bloom.py
View file @
7c4f76e3
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The
CacheFlow
team.
# Copyright 2023 The
vLLM
team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -17,19 +17,19 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import
math
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
BloomConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -40,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
total_num_heads
))
...
...
@@ -107,23 +105,22 @@ class BloomAttention(nn.Module):
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
def
forward
(
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
del
position_ids
# Unused.
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
...
...
@@ -180,8 +177,8 @@ class BloomBlock(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
# Layer norm at the beginning of the transformer layer.
layernorm_output
=
self
.
input_layernorm
(
hidden_states
)
...
...
@@ -197,7 +194,7 @@ class BloomBlock(nn.Module):
position_ids
=
position_ids
,
hidden_states
=
layernorm_output
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
attention_output
=
attention_output
+
residual
layernorm_output
=
self
.
post_attention_layernorm
(
attention_output
)
...
...
@@ -244,8 +241,8 @@ class BloomModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
word_embeddings
(
input_ids
)
hidden_states
=
self
.
word_embeddings_layernorm
(
hidden_states
)
...
...
@@ -255,7 +252,7 @@ class BloomModel(nn.Module):
position_ids
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -273,26 +270,32 @@ class BloomForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
transformer
=
BloomModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/chatglm.py
View file @
7c4f76e3
...
...
@@ -2,24 +2,25 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -28,8 +29,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
ChatGLMConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GLMAttention
(
nn
.
Module
):
...
...
@@ -87,7 +86,7 @@ class GLMAttention(nn.Module):
base
=
10000
*
rope_ratio
,
is_neox_style
=
False
,
)
self
.
attn
=
Paged
Attention
(
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
...
...
@@ -98,20 +97,18 @@ class GLMAttention(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
key_cache
,
value_cache
=
kv_cache
context_layer
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
,
kv_cache
,
attn_metadata
,
)
attn_output
,
_
=
self
.
dense
(
context_layer
)
return
attn_output
...
...
@@ -199,8 +196,8 @@ class GLMBlock(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
...
...
@@ -210,7 +207,7 @@ class GLMBlock(nn.Module):
hidden_states
=
layernorm_output
,
position_ids
=
position_ids
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Residual connection.
...
...
@@ -263,8 +260,8 @@ class GLMTransformer(nn.Module):
self
,
hidden_states
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
for
i
in
range
(
self
.
num_layers
):
layer
=
self
.
layers
[
i
]
...
...
@@ -272,7 +269,7 @@ class GLMTransformer(nn.Module):
hidden_states
=
hidden_states
,
position_ids
=
position_ids
,
kv_cache
=
kv_caches
[
i
],
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Final layer norm.
if
self
.
post_layer_norm
:
...
...
@@ -305,8 +302,8 @@ class ChatGLMModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embedding
(
input_ids
)
...
...
@@ -315,43 +312,63 @@ class ChatGLMModel(nn.Module):
hidden_states
=
inputs_embeds
,
position_ids
=
position_ids
,
kv_caches
=
kv_caches
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
return
hidden_states
class
ChatGLMForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
"dense_h_to_4h"
:
[
"dense_h_to_4h"
]
}
# LoRA specific attributes
supported_lora_modules
=
[
"query_key_value"
,
"dense"
,
"dense_h_to_4h"
,
"dense_4h_to_h"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
config
:
ChatGLMConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
:
ChatGLMConfig
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
ChatGLMModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
sampler
=
Sampler
(
config
.
padded_vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/commandr.py
0 → 100644
View file @
7c4f76e3
# coding=utf-8
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
transformers
import
CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension, computed in float32.

    The learnable ``weight`` starts at ones; an optional ``bias`` starts at
    zeros and is only created when ``bias=True``. The module also accepts a
    ``residuals`` argument that it hands back untouched, so it can be called
    with the same ``(hidden_states, residual)`` convention as the other
    norm layers used by the decoder.
    """

    def __init__(self,
                 hidden_size: int,
                 eps: float = 1e-5,
                 bias: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # No bias parameter is allocated unless explicitly requested.
        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
        self.variance_epsilon = eps

    def forward(
        self,
        hidden_states: torch.Tensor,
        residuals: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Normalise ``hidden_states`` along its last dimension.

        Args:
            hidden_states: Tensor to normalise.
            residuals: Passed through unchanged; never read here.

        Returns:
            ``(normalised, residuals)`` where ``normalised`` has the same
            dtype as the input.
        """
        input_dtype = hidden_states.dtype
        # Do the statistics in float32 regardless of the input precision.
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        # Centre once and reuse it for both the variance and the output.
        centered = hidden_states - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        hidden_states = centered * torch.rsqrt(variance +
                                               self.variance_epsilon)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        if self.bias is not None:
            hidden_states = hidden_states + self.bias.to(torch.float32)
        return hidden_states.to(input_dtype), residuals
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
    """Gated feed-forward block of the Cohere decoder layer.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` whose output is fed to ``SiluAndMul``,
    followed by a ``RowParallelLinear`` down projection. No projection
    carries a bias.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the projections from ``config.hidden_size`` and
        ``config.intermediate_size``.

        Args:
            config: Model config providing ``hidden_size`` and
                ``intermediate_size``.
            linear_method: Forwarded to both linear layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # One fused projection producing two outputs of intermediate_size.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply fused gate/up projection, activation and down projection."""
        # Each linear layer returns a pair; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class CohereAttention(nn.Module):
    """Self-attention block of the Cohere decoder layer.

    Fused QKV projection, rotary position embedding applied to queries and
    keys (``is_neox_style=False``), the vLLM ``Attention`` layer, and an
    output projection. Head counts are divided by the tensor-parallel world
    size.
    """

    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Derive per-rank head counts and build the sub-layers.

        Args:
            config: Cohere model configuration.
            linear_method: Forwarded to the QKV and output projections.
        """
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # rope_scaling is optional on the config; default to None.
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Token positions fed to the rotary embedding.
            hidden_states: Input to the QKV projection.
            kv_cache: Cache tensor handed to the attention layer.
            attn_metadata: Metadata handed to the attention layer.

        Returns:
            The output-projected attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class CohereDecoderLayer(nn.Module):
    """One Cohere decoder layer.

    A single input layer norm feeds both the attention block and the MLP;
    their outputs are added together with the un-normalised input.
    """

    def __init__(self,
                 config: CohereConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        """Build the attention, MLP and input layer norm.

        Args:
            config: Cohere model configuration.
            linear_method: Forwarded to the attention and MLP sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = CohereAttention(config, linear_method=linear_method)

        self.mlp = CohereMLP(config, linear_method=linear_method)
        self.input_layernorm = LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer to ``hidden_states``.

        Args:
            positions: Token positions for the attention block.
            hidden_states: Layer input.
            kv_cache: Cache tensor for this layer's attention.
            attn_metadata: Metadata for the attention layer.
            residual: Accepted for interface compatibility; its incoming
                value is ignored because it is overwritten immediately.

        Returns:
            ``(hidden_states, residual)``: the layer output and the layer
            input that was used as the residual.
        """
        # Self Attention
        # The residual is always this layer's own input.
        residual = hidden_states
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states_attention = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The MLP consumes the same normalised input as attention.
        hidden_states_mlp = self.mlp(hidden_states)
        # Add everything together
        hidden_states = residual + hidden_states_attention + hidden_states_mlp

        return hidden_states, residual
class CohereModel(nn.Module):
    """Cohere transformer stack: token embedding, decoder layers, final norm."""

    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the embedding, ``config.num_hidden_layers`` decoder layers
        and the final layer norm.

        Args:
            config: Cohere model configuration.
            linear_method: Forwarded to every decoder layer.
        """
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            CohereDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions passed to each layer.
            kv_caches: One cache tensor per decoder layer, indexed by layer.
            attn_metadata: Metadata passed to each layer.

        Returns:
            The hidden states after the final layer norm.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        # The norm returns (normalised, residual); the residual is dropped.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class CohereForCausalLM(nn.Module):
    """Cohere causal language model: transformer stack, logits and sampling.

    Logits are computed against the input embedding matrix
    (``self.model.embed_tokens.weight``); there is no separate LM head
    parameter in this class.
    """

    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the logits processor, the model and the sampler.

        Args:
            config: Cohere model configuration; ``vocab_size`` and
                ``logit_scale`` are passed to the logits processor.
            linear_method: Forwarded to the underlying model.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(config, linear_method)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states for ``input_ids``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the embedding weights."""
        logits = self.logits_processor(self.model.embed_tokens.weight,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load checkpoint weights into this module's parameters.

        Checkpoint tensors whose name contains one of the shard names below
        are routed into the corresponding fused parameter through that
        parameter's ``weight_loader`` with the shard id. All other tensors
        are loaded by exact name, using the parameter's ``weight_loader``
        when it has one and ``default_weight_loader`` otherwise.

        Args:
            model_name_or_path: Passed to ``hf_model_weights_iterator``.
            cache_dir: Passed to ``hf_model_weights_iterator``.
            load_format: Passed to ``hf_model_weights_iterator``.
            revision: Passed to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a (possibly renamed) checkpoint tensor name has no
                matching parameter in this module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # Redirect the shard into its fused parameter.
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not part of any fused parameter: load by exact name.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
vllm/model_executor/models/dbrx.py
0 → 100644
View file @
7c4f76e3
# coding=utf-8
from
typing
import
List
,
Optional
import
torch
import
torch.nn
as
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
class DbrxRouter(nn.Module):
    """A Router implementation for DBRX that returns logits for each expert
    per token.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Build the routing projection.

        Args:
            config: DBRX configuration; ``d_model`` and
                ``ffn_config.moe_num_experts`` size the projection.
            params_dtype: Forwarded to the linear layer.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.d_model = config.d_model
        # Bias-free d_model -> num_experts projection; linear_method is
        # explicitly None for this layer.
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            linear_method=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the router logits for ``hidden_states``."""
        # The layer returns a pair; only the first element is used.
        router_logits, _ = self.layer(hidden_states)
        return router_logits
class DbrxExperts(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Allocate the router and the stacked expert weights.

        Args:
            config: Model configuration; ``d_model`` and the
                ``ffn_config`` fields ``moe_num_experts``, ``moe_top_k``
                and ``ffn_hidden_size`` are read from it.
            linear_method: Accepted for interface parity; not used here.
            params_dtype: Dtype of the expert weights. Falls back to
                ``torch.get_default_dtype()`` when ``None``.
        """
        super().__init__()
        ffn_cfg = config.ffn_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = ffn_cfg.moe_num_experts
        self.top_k = ffn_cfg.moe_top_k
        self.d_model = config.d_model
        # Per-rank slice of the FFN hidden dimension.
        self.intermediate_size = ffn_cfg.ffn_hidden_size // self.tp_size
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        self.router = DbrxRouter(config, self.params_dtype)
        # ws holds two stacked projections (w1 then v1) per expert.
        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.d_model,
                device="cuda",
                dtype=self.params_dtype,
            ))
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.d_model,
                self.intermediate_size,
                device="cuda",
                dtype=self.params_dtype,
            ))

        # Both parameters are filled by the sharding-aware loader below.
        for expert_param in (self.ws, self.w2s):
            set_weight_attrs(
                expert_param,
                {
                    "weight_loader": self.weight_loader,
                },
            )

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        """Copy this rank's shard of a checkpoint tensor into ``param``.

        Args:
            param: Destination parameter (``ws`` or ``w2s``).
            loaded_weight: Checkpoint tensor; it is reshaped to
                ``(-1, intermediate_size * tp_size, d_model)`` before
                slicing.
            weight_name: Checkpoint name; its suffix (``w1``, ``v1`` or
                ``w2``) selects the destination. Other names are ignored.
        """
        rank = get_tensor_model_parallel_rank()
        target = param.data
        width = self.intermediate_size
        rank_cols = slice(rank * width, (rank + 1) * width)
        unsharded_shape = [-1, width * self.tp_size, self.d_model]

        # DBRX uses a GLU for each expert.
        # A GLU has 3 linear layers: w1, v1 and w2.
        if weight_name.endswith(("w1", "v1")):
            # w1 fills the first half of ws, v1 the second half.
            row_offset = 0 if weight_name.endswith("w1") else width
            stacked = loaded_weight.reshape(unsharded_shape)
            target[:, row_offset:row_offset + width, :] = (
                stacked[:, rank_cols, :])
        elif weight_name.endswith("w2"):
            # w2 is transposed so the sharded axis becomes the last one.
            stacked = loaded_weight.reshape(unsharded_shape).transpose(1, 2)
            target[:] = stacked[:, :, rank_cols]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the mixed output.

        Args:
            hidden_states: 2-D tensor unpacked as
                ``(num_tokens, hidden_size)``.

        Returns:
            Tensor viewed back to ``(num_tokens, hidden_size)``.
        """
        num_tokens, hidden_size = hidden_states.shape
        flat_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(flat_states)
        mixed = fused_moe(
            flat_states,
            self.ws,
            self.w2s,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
        )

        # Each rank holds a slice of every expert, so partial results are
        # combined across ranks.
        if self.tp_size > 1:
            mixed = tensor_model_parallel_all_reduce(mixed)

        return mixed.view(num_tokens, hidden_size)
class DbrxAttention(nn.Module):
    """DBRX self-attention: fused QKV projection, optional QKV clipping,
    rotary position embedding, attention, and an output projection.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the projections, rotary embedding and attention layer.

        Args:
            config: Model configuration; reads ``d_model``, ``n_heads``,
                ``max_seq_len`` and the ``attn_config`` fields
                ``kv_n_heads``, ``clip_qkv`` and ``rope_theta``.
            linear_method: Forwarded to the QKV and output projections.
        """
        super().__init__()
        attn_cfg = config.attn_config
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = attn_cfg.kv_n_heads
        self.clip_qkv = attn_cfg.clip_qkv
        self.rope_theta = attn_cfg.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        world_size = get_tensor_model_parallel_world_size()
        self.tp_size = world_size
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size
        if self.total_num_kv_heads >= world_size:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across the tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        else:
            # Fewer KV heads than ranks: the KV heads are replicated
            # across the tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # Per-rank widths of the q and k/v slices of the fused output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor handed to the attention layer.
            attn_metadata: Metadata handed to the attention layer.

        Returns:
            Output of ``out_proj`` applied to the attention result.
        """
        fused_qkv, _ = self.Wqkv(hidden_states)
        clip = self.clip_qkv
        if clip is not None:
            # In-place clamp of the fused projection to [-clip, clip].
            fused_qkv.clamp_(min=-clip, max=clip)
        query, key, value = torch.split(
            fused_qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.out_proj(context)
        return projected
class DbrxFusedNormAttention(nn.Module):
    """Pre-norm attention sub-block followed by a second LayerNorm.

    Computes ``h = x + attn(norm_1(x))`` and returns ``(norm_2(h), h)`` so
    the caller can feed the normalized tensor to the FFN and keep ``h`` as
    the residual.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the attention layer and the two LayerNorms.

        Args:
            config: Model configuration; ``d_model`` sizes both norms.
            linear_method: Forwarded to ``DbrxAttention``.
        """
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, linear_method)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply norm, attention, residual add, then the second norm.

        Returns:
            A pair ``(normalized, residual)``: ``residual`` is the input
            plus the attention output, and ``normalized`` is ``norm_2``
            applied to it.
        """
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=self.norm_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The skip connection uses the un-normalized input.
        post_attn = hidden_states + attn_out
        return self.norm_2(post_attn), post_attn
class DbrxBlock(nn.Module):
    """One DBRX layer: the norm/attention/norm sub-block followed by the
    mixture-of-experts FFN with a residual connection.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the attention sub-block and the expert FFN.

        Args:
            config: Model configuration shared by both sub-modules.
            linear_method: Forwarded to both sub-modules.
        """
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
        self.ffn = DbrxExperts(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer and return ``ffn(normalized) + residual``."""
        normalized, skip = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        ffn_out = self.ffn(normalized)
        return ffn_out + skip
class DbrxModel(nn.Module):
    """DBRX transformer body: token embedding, a stack of ``DbrxBlock``
    layers and a final LayerNorm.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the embedding, the layer stack and the final norm.

        Args:
            config: Model configuration; reads ``vocab_size``, ``d_model``
                and ``n_layers``.
            linear_method: Forwarded to every ``DbrxBlock``.
        """
        super().__init__()
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList(
            [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if isinstance(getattr(module, "bias", None), nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every block in order.

        Args:
            input_ids: Token ids passed to the embedding.
            position_ids: Positions passed unchanged to every block.
            kv_caches: One cache entry per block; entry ``i`` is given to
                block ``i``. A list shorter than the number of blocks
                raises ``IndexError``.
            attn_metadata: Passed unchanged to every block.

        Returns:
            The output of the last block after the final LayerNorm.
        """
        hidden_states = self.wte(input_ids)
        # Index kv_caches explicitly (rather than zip) so that a missing
        # cache entry fails loudly instead of silently skipping layers.
        for layer_idx, block in enumerate(self.blocks):
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        return self.norm_f(hidden_states)
class DbrxForCausalLM(nn.Module):
    """DBRX causal language model: the transformer body plus LM head,
    logits processor and sampler, with checkpoint loading.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the model, LM head, logits processor and sampler.

        Args:
            config: Model configuration; reads ``vocab_size`` and
                ``d_model`` here and is forwarded to ``DbrxModel``.
            linear_method: Stored and forwarded to ``DbrxModel``.
        """
        super().__init__()
        vocab_size = config.vocab_size
        self.config = config
        self.linear_method = linear_method
        self.unpadded_vocab_size = vocab_size
        self.transformer = DbrxModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            vocab_size,
            config.d_model,
            org_num_embeddings=vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the transformer's hidden states for ``input_ids``."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the LM head weight."""
        return self.logits_processor(self.lm_head.weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load checkpoint tensors into this model's parameters.

        Expert weights (names containing ``experts.mlp.w1``, ``.v1`` or
        ``.w2``) are redirected to the stacked ``ws`` / ``w2s`` parameters
        and loaded through the parameter's own ``weight_loader``. Every
        other tensor is loaded under its checkpoint name, using the
        parameter's ``weight_loader`` if it has one and
        ``default_weight_loader`` otherwise.

        Raises:
            KeyError: If a (possibly renamed) checkpoint name has no
                matching parameter.
        """
        # (destination parameter, checkpoint name fragment)
        expert_params_mapping = [
            ("ws", "experts.mlp.w1"),
            ("ws", "experts.mlp.v1"),
            ("w2s", "experts.mlp.w2"),
        ]
        # remove_duplicate=False keeps every name of a shared parameter.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                            load_format, revision)
        for name, loaded_weight in weights:
            # First mapping entry whose fragment occurs in the name.
            hit = next(
                (entry for entry in expert_params_mapping
                 if entry[1] in name),
                None,
            )
            if hit is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)
                continue

            param_name, weight_name = hit
            param = params_dict[name.replace(weight_name, param_name)]
            # The fragment tells the expert loader which slot to fill.
            param.weight_loader(param, loaded_weight, weight_name)
vllm/model_executor/models/deepseek.py
View file @
7c4f76e3
...
...
@@ -21,26 +21,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
ReplicatedLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
...
...
@@ -50,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
DeepseekMLP
(
nn
.
Module
):
...
...
@@ -119,7 +117,8 @@ class DeepseekMoE(nn.Module):
linear_method
=
None
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
self
.
shared_experts
=
DeepseekMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
...
...
@@ -148,11 +147,11 @@ class DeepseekMoE(nn.Module):
self
.
w2
=
self
.
w2
.
view
(
len
(
w2
),
*
w2s
[
0
].
shape
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
sequence_length
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
config
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (
batch * sequence_length
, n_experts)
# router_logits: (
num_tokens
, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
w1
,
...
...
@@ -167,8 +166,7 @@ class DeepseekMoE(nn.Module):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
batch_size
,
sequence_length
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
class
DeepseekAttention
(
nn
.
Module
):
...
...
@@ -229,23 +227,22 @@ class DeepseekAttention(nn.Module):
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -273,8 +270,9 @@ class DeepseekDecoderLayer(nn.Module):
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
)
if
(
config
.
n_routed_experts
is
not
None
and
\
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
DeepseekMoE
(
config
=
config
,
linear_method
=
linear_method
)
else
:
self
.
mlp
=
DeepseekMLP
(
...
...
@@ -292,8 +290,8 @@ class DeepseekDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -307,7 +305,7 @@ class DeepseekDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -344,15 +342,15 @@ class DeepseekModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
kv_caches
[
i
],
attn
_metadata
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -370,26 +368,32 @@ class DeepseekForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
model
=
DeepseekModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
Optional
[
torch
.
Tensor
],
logit
s
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/falcon.py
View file @
7c4f76e3
...
...
@@ -19,24 +19,24 @@
"""PyTorch Falcon model."""
import
math
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Union
import
torch
from
torch
import
nn
from
torch.nn
import
LayerNorm
from
transformers
import
FalconConfig
as
HF_FalconConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
...
...
@@ -47,7 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
RWConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
FalconConfig
=
Union
[
HF_FalconConfig
,
RWConfig
]
...
...
@@ -150,10 +149,10 @@ class FalconAttention(nn.Module):
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
elif
self
.
use_alibi
:
tp_rank
=
get_tensor_model_parallel_rank
()
head_start
=
tp_rank
*
self
.
num_heads
...
...
@@ -161,23 +160,23 @@ class FalconAttention(nn.Module):
alibi_slopes
=
(
_get_alibi_slopes
(
self
.
total_num_heads
)
*
self
.
inv_norm_factor
)
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
,
alibi_slopes
=
alibi_slopes
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
,
alibi_slopes
=
alibi_slopes
)
else
:
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
bias
=
self
.
query_key_value
(
hidden_states
)
if
bias
is
not
None
:
...
...
@@ -185,8 +184,7 @@ class FalconAttention(nn.Module):
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
use_rotary
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
bias
=
self
.
dense
(
attn_output
)
return
attn_output
,
bias
...
...
@@ -262,8 +260,8 @@ class FalconDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
...
...
@@ -278,7 +276,7 @@ class FalconDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
attention_layernorm_out
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
if
self
.
reduce_row_parallel_results
and
attention_bias
is
not
None
:
attention_output
+=
attention_bias
...
...
@@ -342,8 +340,8 @@ class FalconModel(nn.Module):
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
word_embeddings
(
input_ids
)
for
i
in
range
(
len
(
self
.
h
)):
...
...
@@ -352,7 +350,7 @@ class FalconModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -369,34 +367,37 @@ class FalconForCausalLM(nn.Module):
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
transformer
=
FalconModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
,
attn
_metadata
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
@@ -412,9 +413,12 @@ class FalconForCausalLM(nn.Module):
else
:
total_num_kv_heads
=
total_num_heads
num_query_heads_per_kv_head
=
total_num_heads
//
total_num_kv_heads
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
name
==
"lm_head.weight"
:
# Falcon uses tied embeddings.
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
...
...
vllm/model_executor/models/gemma.py
View file @
7c4f76e3
...
...
@@ -14,21 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from
functools
import
lru_cache
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
GemmaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.
model_executor.input_metadata
import
InputMetadata
from
vllm.
logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -40,7 +42,33 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
logger
=
init_logger
(
__name__
)
@
lru_cache
(
maxsize
=
None
)
def
_get_gemma_act_fn
(
hidden_act
:
Optional
[
str
],
hidden_activation
:
Optional
[
str
],
)
->
nn
.
Module
:
if
hidden_activation
is
None
:
if
hidden_act
is
not
None
:
logger
.
warning
(
"Gemma's activation function was incorrectly set to exact GeLU "
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
f
"`
{
hidden_act
}
`, edit the config JSON to set "
f
"`hidden_activation=
{
hidden_act
}
` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
"for more details."
)
return
GeluAndMul
(
approximate
=
"tanh"
)
elif
hidden_activation
==
"gelu_pytorch_tanh"
:
return
GeluAndMul
(
approximate
=
"tanh"
)
elif
hidden_activation
==
"gelu"
:
return
GeluAndMul
(
approximate
=
"none"
)
else
:
raise
ValueError
(
f
"Activation function
{
hidden_act
}
is not "
"supported for Gemma models."
)
class
GemmaMLP
(
nn
.
Module
):
...
...
@@ -49,6 +77,8 @@ class GemmaMLP(nn.Module):
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
Optional
[
str
]
=
None
,
hidden_activation
:
Optional
[
str
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -60,7 +90,7 @@ class GemmaMLP(nn.Module):
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
self
.
act_fn
=
GeluAndMul
(
)
self
.
act_fn
=
_get_gemma_act_fn
(
hidden_act
,
hidden_activation
)
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
...
...
@@ -123,23 +153,22 @@ class GemmaAttention(nn.Module):
base
=
self
.
rope_theta
,
is_neox_style
=
True
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -165,6 +194,8 @@ class GemmaDecoderLayer(nn.Module):
self
.
mlp
=
GemmaMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_activation
=
getattr
(
config
,
"hidden_activation"
,
None
),
linear_method
=
linear_method
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
...
@@ -176,8 +207,8 @@ class GemmaDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -191,7 +222,7 @@ class GemmaDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -221,16 +252,22 @@ class GemmaModel(nn.Module):
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
# data type such as bfloat16, not float32.
# See https://github.com/huggingface/transformers/pull/29402
normalizer
=
self
.
config
.
hidden_size
**
0.5
self
.
register_buffer
(
"normalizer"
,
torch
.
tensor
(
normalizer
))
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
# Normalize the embedding by sqrt(hidden_size)
hidden_states
*=
self
.
config
.
hidden_size
**
0.5
hidden_states
*=
self
.
normalizer
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
...
...
@@ -239,7 +276,7 @@ class GemmaModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
...
@@ -281,27 +318,33 @@ class GemmaForCausalLM(nn.Module):
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
GemmaModel
(
config
,
linear_method
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
model
.
embed_tokens
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
@@ -325,11 +368,21 @@ class GemmaForCausalLM(nn.Module):
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if
"lm_head.weight"
in
name
:
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if
"norm.weight"
in
name
:
...
...
vllm/model_executor/models/gpt2.py
View file @
7c4f76e3
...
...
@@ -17,19 +17,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
GPT2Config
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -40,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPT2Attention
(
nn
.
Module
):
...
...
@@ -73,21 +71,17 @@ class GPT2Attention(nn.Module):
bias
=
True
,
linear_method
=
linear_method
,
)
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
key_cache
,
value_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
return
attn_output
...
...
@@ -145,15 +139,15 @@ class GPT2Block(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# residual connection
hidden_states
=
attn_output
+
residual
...
...
@@ -191,8 +185,8 @@ class GPT2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
...
...
@@ -200,7 +194,7 @@ class GPT2Model(nn.Module):
for
i
in
range
(
len
(
self
.
h
)):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn
_metadata
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -218,26 +212,32 @@ class GPT2LMHeadModel(nn.Module):
self
.
linear_method
=
linear_method
self
.
transformer
=
GPT2Model
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
7c4f76e3
...
...
@@ -18,19 +18,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
GPTBigCodeConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -41,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPTBigCodeAttention
(
nn
.
Module
):
...
...
@@ -85,16 +83,16 @@ class GPTBigCodeAttention(nn.Module):
bias
=
True
,
linear_method
=
linear_method
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
(
...
...
@@ -104,9 +102,7 @@ class GPTBigCodeAttention(nn.Module):
],
dim
=-
1
,
)
key_cache
,
value_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
_
=
self
.
c_proj
(
attn_output
)
return
attn_output
...
...
@@ -164,15 +160,15 @@ class GPTBigCodeBlock(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_output
=
self
.
attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# residual connection
hidden_states
=
attn_output
+
residual
...
...
@@ -210,8 +206,8 @@ class GPTBigCodeModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
...
...
@@ -219,7 +215,7 @@ class GPTBigCodeModel(nn.Module):
for
i
in
range
(
len
(
self
.
h
)):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn
_metadata
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -237,26 +233,32 @@ class GPTBigCodeForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
transformer
=
GPTBigCodeModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/gpt_j.py
View file @
7c4f76e3
...
...
@@ -16,23 +16,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
GPTJConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -40,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPTJAttention
(
nn
.
Module
):
...
...
@@ -86,20 +84,19 @@ class GPTJAttention(nn.Module):
base
=
rope_theta
,
is_neox_style
=
False
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_size
,
scaling
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_size
,
scaling
)
def
forward
(
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
,
_
=
self
.
out_proj
(
attn_output
)
return
attn_output
...
...
@@ -143,7 +140,8 @@ class GPTJBlock(nn.Module):
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
):
super
().
__init__
()
inner_dim
=
4
*
config
.
n_embd
if
config
.
n_inner
is
None
else
config
.
n_inner
inner_dim
=
(
4
*
config
.
n_embd
if
config
.
n_inner
is
None
else
config
.
n_inner
)
self
.
ln_1
=
nn
.
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPTJAttention
(
config
,
linear_method
)
self
.
mlp
=
GPTJMLP
(
inner_dim
,
config
,
linear_method
)
...
...
@@ -152,8 +150,8 @@ class GPTJBlock(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
...
...
@@ -161,7 +159,7 @@ class GPTJBlock(nn.Module):
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
mlp_output
=
self
.
mlp
(
hidden_states
)
hidden_states
=
attn_output
+
mlp_output
+
residual
...
...
@@ -190,8 +188,8 @@ class GPTJModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
for
i
in
range
(
len
(
self
.
h
)):
...
...
@@ -200,7 +198,7 @@ class GPTJModel(nn.Module):
position_ids
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -223,26 +221,32 @@ class GPTJForCausalLM(nn.Module):
config
.
n_embd
,
bias
=
True
,
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/gpt_neox.py
View file @
7c4f76e3
...
...
@@ -16,23 +16,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
GPTNeoXConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -40,8 +40,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
GPTNeoXAttention
(
nn
.
Module
):
...
...
@@ -87,20 +85,19 @@ class GPTNeoXAttention(nn.Module):
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_size
,
scaling
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_size
,
scaling
)
def
forward
(
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
query_key_value
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
...
...
@@ -154,15 +151,15 @@ class GPTNeoXLayer(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
attn_input
=
self
.
input_layernorm
(
hidden_states
)
attn_output
=
self
.
attention
(
position_ids
=
position_ids
,
hidden_states
=
attn_input
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
if
self
.
use_parallel_residual
:
...
...
@@ -207,8 +204,8 @@ class GPTNeoXModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_in
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
...
...
@@ -217,7 +214,7 @@ class GPTNeoXModel(nn.Module):
position_ids
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
return
hidden_states
...
...
@@ -238,26 +235,32 @@ class GPTNeoXForCausalLM(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
gpt_neox
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
embed_out
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
embed_out
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/internlm2.py
View file @
7c4f76e3
...
...
@@ -5,18 +5,18 @@ import torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -24,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
InternLM2MLP
(
nn
.
Module
):
...
...
@@ -114,23 +112,22 @@ class InternLM2Attention(nn.Module):
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
wqkv
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
wo
(
attn_output
)
return
output
...
...
@@ -171,8 +168,8 @@ class InternLMDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -186,7 +183,7 @@ class InternLMDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -220,8 +217,8 @@ class InternLM2Model(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
tok_embeddings
(
input_ids
)
residual
=
None
...
...
@@ -231,7 +228,7 @@ class InternLM2Model(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
...
@@ -250,26 +247,32 @@ class InternLM2ForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
model
=
InternLM2Model
(
config
,
linear_method
)
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
output
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
output
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
@@ -305,7 +308,8 @@ class InternLM2ForCausalLM(nn.Module):
param
=
params_dict
[
name
]
if
"wqkv"
in
name
:
config
=
self
.
config
kv_groups
=
config
.
num_attention_heads
//
config
.
num_key_value_heads
kv_groups
=
(
config
.
num_attention_heads
//
config
.
num_key_value_heads
)
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
loaded_weight
=
loaded_weight
.
view
(
-
1
,
2
+
kv_groups
,
head_dim
,
...
...
vllm/model_executor/models/jais.py
0 → 100644
View file @
7c4f76e3
# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
JAISConfig
class SwiGLUActivation(nn.Module):
    """Gated activation: multiplies one tensor by the SiLU of another."""

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return ``x1 * silu(x2)``, computed elementwise.

        Args:
            x1: Tensor that is scaled by the gate.
            x2: Tensor passed through SiLU to form the gate.

        Returns:
            The elementwise product of ``x1`` and ``silu(x2)``.
        """
        gate = nn.functional.silu(x2)
        return x1 * gate
def _get_alibi_slopes(n):
    """Return a list of ``n`` ALiBi slope values.

    When ``n`` is a power of two the result is a geometric sequence whose
    first term and ratio are both ``2 ** -(2 ** -(log2(n) - 3))``.
    Otherwise the slopes for the largest power of two below ``n`` are
    followed by every other slope of the next power of two, truncated so
    that the total length is ``n``.

    Args:
        n: Number of slopes to produce (callers pass the head count).

    Returns:
        A list of ``n`` floats.
    """

    def _geometric(count):
        # First term and common ratio are the same value.
        base = 2**(-(2**-(math.log2(count) - 3)))
        return [base * base**k for k in range(count)]

    if not math.log2(n).is_integer():
        lower_pow2 = 2**math.floor(math.log2(n))
        head_part = _geometric(lower_pow2)
        # Take the even-indexed slopes of the next power of two to fill
        # the remaining n - lower_pow2 positions.
        tail_part = _get_alibi_slopes(2 * lower_pow2)[0::2][:n - lower_pow2]
        return head_part + tail_part

    return _geometric(n)
class JAISAttention(nn.Module):
    """Self-attention layer for Jais with a fused QKV projection.

    The attention heads are divided evenly across tensor-parallel ranks
    and each rank passes its own slice of the ALiBi slopes to the
    attention backend.
    """

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the projections and the attention backend.

        Args:
            config: Model configuration; reads ``hidden_size``,
                ``num_attention_heads`` and the QK scaling flag.
            linear_method: Optional linear method forwarded to the
                parallel linear layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tp_size == 0
        self.num_heads = total_num_heads // tp_size
        self.head_dim = self.hidden_size // total_num_heads

        # Mirror the alternative config key onto the "mup_" name.
        # Note that this writes to the config object passed in.
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        # Scale by 1/d when the flag is set, otherwise by 1/sqrt(d).
        if config.mup_scale_qk_dot_by_d:
            self.attn_scale_power = 1.0
        else:
            self.attn_scale_power = 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )

        # Select the slopes belonging to this rank's contiguous heads.
        # NOTE(review): slopes are passed regardless of
        # config.position_embedding_type -- confirm this is intended.
        tp_rank = get_tensor_model_parallel_rank()
        first_head = tp_rank * self.num_heads
        last_head = first_head + self.num_heads
        all_slopes = _get_alibi_slopes(total_num_heads)
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            alibi_slopes=all_slopes[first_head:last_head],
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache tensor.
            attn_metadata: Metadata forwarded to the attention backend.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        fused_qkv, _ = self.c_attn(hidden_states)
        # The fused projection is split into three equal parts.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected
class JAISMLP(nn.Module):
    """Feed-forward block for Jais with an optional second (gate) branch.

    NOTE(review): ``self.act`` is always a ``SwiGLUActivation``, whose
    ``forward`` takes two tensors. When ``config.activation_function`` is
    not ``"swiglu"``, ``forward`` calls it with a single tensor, which
    would fail as written -- confirm only swiglu configs are expected.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the up, gate and down projections.

        Args:
            intermediate_size: Width of the hidden feed-forward layer.
            config: Model configuration; reads ``hidden_size`` and
                ``activation_function``.
            linear_method: Optional linear method forwarded to the
                parallel linear layers.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            linear_method=linear_method,
        )
        # The second projection exists only for the gated variant.
        if self.swiglu:
            self.c_fc2 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=True,
                linear_method=linear_method,
            )
        else:
            self.c_fc2 = None
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        self.act = SwiGLUActivation()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to ``hidden_states``.

        Args:
            hidden_states: Input activations.

        Returns:
            The output of ``c_proj`` applied to the activated projection.
        """
        if self.swiglu:
            gate, _ = self.c_fc2(hidden_states)
            up, _ = self.c_fc(hidden_states)
            activated = self.act(up, gate)
        else:
            up, _ = self.c_fc(hidden_states)
            # See the class-level review note about this call.
            activated = self.act(up)
        output, _ = self.c_proj(activated)
        return output
class JAISBlock(nn.Module):
    """One Jais transformer layer: pre-norm attention, then pre-norm MLP."""

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the layer norms, attention and MLP sub-modules.

        Args:
            config: Model configuration; reads ``hidden_size``,
                ``n_inner`` and ``layer_norm_epsilon``.
            linear_method: Optional linear method forwarded to the
                sub-modules.
        """
        super().__init__()
        hidden_size = config.hidden_size
        # The MLP width defaults to 4x the hidden size when unset.
        if config.n_inner is not None:
            inner_dim = config.n_inner
        else:
            inner_dim = 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = JAISAttention(config, linear_method)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = JAISMLP(inner_dim, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP, each wrapped in a residual connection.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache tensor.
            attn_metadata: Metadata forwarded to the attention module.

        Returns:
            The layer output after both residual additions.
        """
        # Attention sub-layer: normalise, attend, add the input back.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        after_attn = attn_out + hidden_states

        # MLP sub-layer: normalise, transform, add the input back.
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out
class JAISModel(nn.Module):
    """Jais decoder stack: embeddings, transformer blocks, final norm."""

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the embeddings, the blocks and the final layer norm.

        Args:
            config: Model configuration. Cross attention, inverse-layer
                attention scaling and attention reordering/upcasting must
                all be disabled.
            linear_method: Optional linear method forwarded to each block.
        """
        super().__init__()
        self.config = config
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)

        # A learned position table is created only when the config does
        # not select ALiBi.
        if config.position_embedding_type != "alibi":
            self.wpe = nn.Embedding(config.max_position_embeddings,
                                    self.embed_dim)
        else:
            self.wpe = None

        # Two config spellings are accepted for the embedding multiplier.
        if not hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.mup_embeddings_scale
        else:
            self.embeddings_scale = config.embeddings_scale

        blocks = [
            JAISBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.h = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the tokens and run them through every block.

        Args:
            input_ids: Token ids to embed.
            position_ids: Position ids; used only when ``wpe`` exists.
            kv_caches: One KV cache tensor per block, indexed by layer.
            attn_metadata: Metadata forwarded to every block.

        Returns:
            The hidden states after the final layer norm.
        """
        hidden_states = self.wte(input_ids)
        if self.wpe is not None:
            hidden_states = hidden_states + self.wpe(position_ids)
        # The scaling is applied in place, in the activations' dtype.
        hidden_states *= torch.tensor(float(self.embeddings_scale),
                                      dtype=hidden_states.dtype)

        for layer_idx, block in enumerate(self.h):
            hidden_states = block(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)

        return self.ln_f(hidden_states)
class JAISLMHeadModel(nn.Module):
    """Jais causal LM: decoder stack, tied LM head, logits and sampling."""

    def __init__(
        self,
        config: JAISConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the transformer, the logits processor and the sampler.

        Args:
            config: Model configuration; reads ``vocab_size`` and the
                output-logit scaling fields.
            linear_method: Optional linear method forwarded to the
                transformer.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = JAISModel(config, linear_method)
        # The LM head reuses the token embedding weight.
        self.lm_head_weight = self.transformer.wte.weight

        # Two config spellings are accepted for the logit multiplier.
        if not hasattr(config, "width_scale"):
            logits_scale = config.mup_output_alpha * config.mup_width_scale
        else:
            logits_scale = config.width_scale
        self.output_logits_scale = logits_scale

        self.logits_processor = LogitsProcessor(
            vocab_size=config.vocab_size, scale=self.output_logits_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the decoder stack and return its final hidden states."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the tied LM head weight."""
        return self.logits_processor(self.lm_head_weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load checkpoint tensors into this model's parameters.

        Args:
            model_name_or_path: Checkpoint identifier or local path.
            cache_dir: Optional cache directory for the weights iterator.
            load_format: Format hint for the weights iterator.
            revision: Optional checkpoint revision.

        Raises:
            KeyError: If a checkpoint tensor that is not skipped has no
                matching parameter in this model.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        # Projections whose ".weight" tensors are transposed on load. Per
        # the original comments, the HF implementation stores them in
        # Conv1D layout rather than Linear layout.
        # Note(zhuohan): the logic below might break quantized models.
        transposed_layers = ("c_attn", "c_proj", "c_fc")

        weights = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                            load_format, revision)
        for name, loaded_weight in weights:
            # lm_head is tied to the embedding weight, so the checkpoint
            # copy is not loaded.
            if "lm_head.weight" in name:
                continue
            # Skip the attention mask buffers.
            # NOTE: "c_attn.bias" should not be skipped.
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                continue
            if "relative_pe" in name:
                continue

            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]

            if name.endswith(".weight"):
                # One transpose per matching layer name, as in the
                # original loop.
                for layer_key in transposed_layers:
                    if layer_key in name:
                        loaded_weight = loaded_weight.t()

            load_fn = getattr(param, "weight_loader", default_weight_loader)
            load_fn(param, loaded_weight)
vllm/model_executor/models/llama.py
View file @
7c4f76e3
...
...
@@ -27,19 +27,19 @@ import torch
from
torch
import
nn
from
transformers
import
LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -47,8 +47,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
LlamaMLP
(
nn
.
Module
):
...
...
@@ -139,24 +137,23 @@ class LlamaAttention(nn.Module):
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
sliding_window
=
sliding_window
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
sliding_window
=
sliding_window
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -202,8 +199,8 @@ class LlamaDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
...
...
@@ -217,7 +214,7 @@ class LlamaDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -253,14 +250,21 @@ class LlamaModel(nn.Module):
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
Optional
[
torch
.
Tensor
]
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input_metadata
:
InputMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
...
...
@@ -268,7 +272,7 @@ class LlamaModel(nn.Module):
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
...
@@ -325,26 +329,35 @@ class LlamaForCausalLM(nn.Module):
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
self
.
sampler
=
Sampler
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/llava.py
0 → 100644
View file @
7c4f76e3
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from
transformers
import
CLIPVisionModel
,
LlavaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VisionLanguageConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Linear -> activation -> linear stack applied to image features.

    The first linear layer maps ``vision_hidden_size`` to
    ``text_hidden_size``; the second keeps ``text_hidden_size``. Both layers
    carry a bias. The activation is resolved by name through ``get_act_fn``.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        # Sub-modules are created in the same order as they are applied.
        self.linear_1 = nn.Linear(in_features=vision_hidden_size,
                                  out_features=text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(in_features=text_hidden_size,
                                  out_features=text_hidden_size,
                                  bias=True)

    def forward(self, image_features):
        """Return ``linear_2(act(linear_1(image_features)))``."""
        projected = self.linear_1(image_features)
        activated = self.act(projected)
        return self.linear_2(activated)
def _merge_vision_embeddings(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             vision_embeddings: torch.Tensor,
                             image_token_id: int):
    """Write vision embeddings into ``inputs_embeds`` in place.

    Every position where ``input_ids == image_token_id`` is overwritten with
    one row of ``vision_embeddings`` (flattened to 2-D on its last
    dimension), in order.

    Args:
        input_ids: Token ids used to locate the image placeholder positions.
        inputs_embeds: Embedding tensor that is modified in place.
        vision_embeddings: Embeddings to insert; all leading dimensions are
            flattened into rows.
        image_token_id: Token id that marks an image placeholder.

    Raises:
        ValueError: If the number of placeholder positions differs from the
            number of vision embedding rows.
    """
    mask = (input_ids == image_token_id)
    vision_rows = vision_embeddings.view(-1, vision_embeddings.shape[-1])

    # Without this check a single vision row would be silently broadcast to
    # every placeholder, and other mismatches would surface as an opaque
    # indexing error. Note that reading the count forces a host-side value.
    num_placeholders = int(mask.sum())
    num_rows = vision_rows.shape[0]
    if num_placeholders != num_rows:
        raise ValueError(
            f"Attempted to assign {num_rows} vision embedding rows to "
            f"{num_placeholders} image token placeholders.")

    inputs_embeds[mask] = vision_rows
class LlavaForConditionalGeneration(nn.Module):
    """Llava model: optional CLIP vision tower, projector, and Llama decoder.

    The vision tower is only built when the configured image input type is
    ``PIXEL_VALUES``; otherwise ``image_input`` passed to ``forward`` is
    treated as already-extracted image features.
    """

    def __init__(self,
                 config: "LlavaConfig",
                 vision_language_config: VisionLanguageConfig,
                 linear_method: Optional["LinearMethodBase"] = None) -> None:
        super().__init__()
        self.config = config

        self.vision_language_config = vision_language_config

        assert self.vision_language_config, (
            "Provide `image_input_type` and other vision "
            "related configurations through LLM entrypoint "
            "or engine arguments.")

        if self.vision_language_config.image_input_type == (
                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
            self.vision_tower = CLIPVisionModel(config.vision_config)
        else:
            self.vision_tower = None

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.linear_method = linear_method
        self.language_model = LlamaModel(config.text_config, linear_method)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        # Falls back to 1.0 when the config defines no `logit_scale`.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                image_input: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run forward pass for Llava 1.5 and return the hidden states.

        One key thing to understand is the `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\\nUSER: What's the content of the image?\\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the
        context length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        The model takes two types of image inputs:
        PIXEL_VALUES and IMAGE_FEATURES.
        The following shows how each maps to huggingface implementation.
        PIXEL_VALUES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
        IMAGE_FEATURES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
        before going through the multi modal projector.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            kv_caches: Passed through to the language model.
            attn_metadata: Passed through to the language model.
            image_input: A batch of image inputs.
                For PIXEL_VALUES, expecting [1, 3, 336, 336].
                For IMAGE_FEATURES, expecting [1, 576, 1024].

        Returns:
            The output of ``self.language_model`` (hidden states, not a
            sampler output).

        Raises:
            ValueError: If ``image_input`` does not match the configured
                image shape (ignoring the batch dimension), or the configured
                feature select strategy is neither "default" nor "full".
        """  # noqa: E501
        if image_input is None:
            inputs_embeds = None
        else:
            expected_shape = self.vision_language_config.image_input_shape[1:]
            if list(image_input.shape[1:]) != list(expected_shape):
                raise ValueError(
                    "The expected image tensor shape is batch dimension "
                    "plus "
                    f"{self.vision_language_config.image_input_shape[1:]}."
                    f" You supplied {image_input.shape}. "
                    "If you are using vLLM's entrypoint, make sure your "
                    "supplied image input is consistent with "
                    "image_input_shape in engine args.")

            if self.vision_tower is not None:
                # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
                image_outputs = self.vision_tower(image_input,
                                                  output_hidden_states=True)
                image_features = image_outputs.hidden_states[
                    self.config.vision_feature_layer]
                # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
                strategy = self.config.vision_feature_select_strategy
                if strategy == "default":
                    # Drop the first position along dim 1.
                    image_features = image_features[:, 1:]
                elif strategy != "full":
                    # "full" keeps the features untouched.
                    raise ValueError(
                        "Unexpected select feature strategy: "
                        f"{strategy}")
            else:
                # No vision tower: the input already holds image features.
                image_features = image_input

            vision_embeddings = self.multi_modal_projector(image_features)
            inputs_embeds = self.language_model.get_input_embeddings(
                input_ids)
            _merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_language_config.image_token_id)
            # The language model consumes the merged embeddings instead.
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from hidden states using the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` via ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint tensors into this model's parameters.

        Checkpoint names are first rewritten with ``_KEYS_TO_MODIFY_MAPPING``.
        Names containing "vision" are loaded with the default path only when
        a vision tower exists, and skipped otherwise. Other names are routed
        into stacked parameters (qkv / gate_up) when they match, or loaded
        with the default path.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # No stacked mapping matched this name.
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
vllm/model_executor/models/mixtral.py
View file @
7c4f76e3
...
...
@@ -21,25 +21,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
from
torch
import
nn
from
transformers
import
MixtralConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
...
...
@@ -50,8 +50,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
MixtralMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
...
...
@@ -123,9 +121,9 @@ class MixtralMoE(nn.Module):
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
sequence_length
,
hidden_size
=
hidden_states
.
shape
num_tokens
,
hidden_size
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
# router_logits: (
batch * sequence_length
, n_experts)
# router_logits: (
num_tokens
, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
ws
,
...
...
@@ -139,8 +137,7 @@ class MixtralMoE(nn.Module):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
batch_size
,
sequence_length
,
hidden_size
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_size
)
class
MixtralAttention
(
nn
.
Module
):
...
...
@@ -197,7 +194,7 @@ class MixtralAttention(nn.Module):
base
=
int
(
self
.
rope_theta
),
is_neox_style
=
True
,
)
self
.
attn
=
Paged
Attention
(
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
...
...
@@ -209,14 +206,13 @@ class MixtralAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -254,8 +250,8 @@ class MixtralDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -269,7 +265,7 @@ class MixtralDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -309,15 +305,15 @@ class MixtralModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
kv_caches
[
i
],
attn
_metadata
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -369,26 +365,33 @@ class MixtralForCausalLM(nn.Module):
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
self
.
sampler
=
Sampler
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
Optional
[
torch
.
Tensor
],
logit
s
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
7c4f76e3
...
...
@@ -21,27 +21,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
transformers
import
MixtralConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
ReplicatedLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
...
...
@@ -51,8 +49,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
MixtralMLP
(
nn
.
Module
):
...
...
@@ -131,9 +127,9 @@ class MixtralMoE(nn.Module):
linear_method
=
None
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
sequence_length
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
# router_logits: (
batch * sequence_length
, n_experts)
# router_logits: (
num_tokens
, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
routing_weights
=
F
.
softmax
(
router_logits
,
dim
=
1
,
dtype
=
torch
.
float
)
...
...
@@ -157,7 +153,7 @@ class MixtralMoE(nn.Module):
final_hidden_states
.
add_
(
current_hidden_states
)
return
tensor_model_parallel_all_reduce
(
final_hidden_states
).
view
(
batch_size
,
sequence_length
,
hidden_dim
)
num_tokens
,
hidden_dim
)
class
MixtralAttention
(
nn
.
Module
):
...
...
@@ -214,7 +210,7 @@ class MixtralAttention(nn.Module):
base
=
int
(
self
.
rope_theta
),
is_neox_style
=
True
,
)
self
.
attn
=
Paged
Attention
(
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
...
...
@@ -226,14 +222,13 @@ class MixtralAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
@@ -268,8 +263,8 @@ class MixtralDecoderLayer(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
...
...
@@ -283,7 +278,7 @@ class MixtralDecoderLayer(nn.Module):
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
# Fully Connected
...
...
@@ -318,15 +313,15 @@ class MixtralModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
kv_caches
[
i
],
attn
_metadata
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -344,26 +339,32 @@ class MixtralForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
model
=
MixtralModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
Optional
[
torch
.
Tensor
],
logit
s
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/mpt.py
View file @
7c4f76e3
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import
math
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
import
torch.nn
as
nn
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -24,8 +24,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
def
_get_alibi_slopes
(
total_num_heads
:
int
,
...
...
@@ -105,18 +103,18 @@ class MPTAttention(nn.Module):
self
.
head_dim
=
self
.
d_model
//
self
.
total_num_heads
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
del
position_ids
# unused.
qkv
,
_
=
self
.
Wqkv
(
hidden_states
)
...
...
@@ -126,8 +124,7 @@ class MPTAttention(nn.Module):
if
self
.
qk_ln
:
q
=
self
.
q_ln
(
q
)
k
=
self
.
k_ln
(
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
...
...
@@ -183,15 +180,15 @@ class MPTBlock(nn.Module):
self
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input
_metadata
:
Input
Metadata
,
kv_cache
:
torch
.
Tensor
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
x
=
self
.
norm_1
(
hidden_states
)
x
=
self
.
attn
(
position_ids
=
position_ids
,
hidden_states
=
x
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
hidden_states
=
hidden_states
+
x
x
=
self
.
norm_2
(
hidden_states
)
...
...
@@ -229,8 +226,8 @@ class MPTModel(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
for
i
in
range
(
len
(
self
.
blocks
)):
...
...
@@ -239,7 +236,7 @@ class MPTModel(nn.Module):
position_ids
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
hidden_states
=
self
.
norm_f
(
hidden_states
)
return
hidden_states
...
...
@@ -259,26 +256,32 @@ class MPTForCausalLM(nn.Module):
self
.
transformer
=
MPTModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
input
_metadata
:
Input
Metadata
,
kv_caches
:
List
[
torch
.
Tensor
],
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
...
...
vllm/model_executor/models/neuron/llama.py
deleted
100644 → 0
View file @
2da0dd3e
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import
os
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
LlamaConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class LlamaForCausalLM(nn.Module):
    """LLaMA wrapper that delegates execution to a transformers-neuronx model.

    ``self.model`` stays ``None`` until ``load_weights`` has been called.
    """

    def __init__(self, config: LlamaConfig, linear_method=None) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        # Populated by load_weights().
        self.model = None
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Call the wrapped model under ``torch.inference_mode``.

        ``positions`` is passed as ``cache_ids``. ``start_ids`` is derived
        from the first column of ``slot_mapping`` floor-divided by the last
        context bucket for prompts, and from ``block_tables`` otherwise;
        either way it is flattened. ``kv_caches`` is not used here.
        """
        with torch.inference_mode():
            block_size = self.model.context_buckets[-1]
            seq_ids = (input_metadata.slot_mapping[:, 0] // block_size
                       if input_metadata.is_prompt else
                       input_metadata.block_tables)
            return self.model(input_ids,
                              cache_ids=positions,
                              start_ids=seq_ids.flatten())

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample via ``self.sampler`` using the checkpoint model's LM head."""
        lm_head = self.model.chkpt_model.lm_head
        return self.sampler(lm_head, hidden_states, sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None,
                     **kwargs):
        """Load the model from a split checkpoint and move it to Neuron.

        If ``<model_name_or_path>/pytorch_model.bin`` is a directory, the
        given path is used directly. Otherwise ``<model_name_or_path>-split``
        is used, and is created from the HF checkpoint first when it does
        not exist. ``cache_dir``, ``load_format`` and ``revision`` are not
        used; ``kwargs`` are forwarded to ``LlamaForSampling.from_pretrained``.
        """
        from transformers_neuronx.llama.model import LlamaForSampling

        default_split_dir = f"{model_name_or_path}-split"
        weights_entry = os.path.join(model_name_or_path, "pytorch_model.bin")

        if os.path.isdir(weights_entry):
            split_model_dir = model_name_or_path
        else:
            split_model_dir = default_split_dir
            if not os.path.exists(default_split_dir):
                # Imported lazily; this local name shadows the class only
                # inside this method.
                from transformers.models.llama import LlamaForCausalLM
                from transformers_neuronx.module import save_pretrained_split

                hf_model = LlamaForCausalLM.from_pretrained(
                    model_name_or_path, low_cpu_mem_usage=True)
                save_pretrained_split(hf_model, default_split_dir)

        self.model = LlamaForSampling.from_pretrained(split_model_dir,
                                                      **kwargs)
        self.model.to_neuron()
Prev
1
…
10
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment