norm / vllm · Commits

Commit a130cf33
Authored Mar 06, 2024 by zhuwenwen

    Merge tag 'v0.3.3' into vllm-v0.3.2-dtk23.10 and add gfx

Parents: a2d181be, 82091b86
Changes: 106

Showing 20 changed files with 661 additions and 336 deletions (+661, -336)
vllm/model_executor/models/baichuan.py                    +3    -3
vllm/model_executor/models/decilm.py                      +1    -1
vllm/model_executor/models/gemma.py                       +66   -53
vllm/model_executor/models/llama.py                       +6    -2
vllm/model_executor/models/neuron/llama.py                +79   -0
vllm/model_executor/models/olmo.py                        +3    -1
vllm/model_executor/models/orion.py                       +64   -119
vllm/model_executor/models/qwen.py                        +4    -4
vllm/model_executor/models/stablelm.py                    +10   -6
vllm/model_executor/models/starcoder2.py                  +310  -0
vllm/model_executor/neuron_model_loader.py                +66   -0
vllm/model_executor/parallel_utils/custom_all_reduce.py   +2    -2
vllm/model_executor/parallel_utils/parallel_state.py      +1    -1
vllm/model_executor/sampling_metadata.py                  +2    -2
vllm/model_executor/utils.py                              +17   -0
vllm/sampling_params.py                                   +15   -0
vllm/transformers_utils/config.py                         +10   -2
vllm/transformers_utils/configs/__init__.py               +2    -6
vllm/transformers_utils/configs/baichuan.py               +0    -62
vllm/transformers_utils/configs/olmo.py                   +0    -72
vllm/model_executor/models/baichuan.py

@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
+from transformers import PretrainedConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -42,7 +43,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -186,7 +186,7 @@ class BaiChuanAttention(nn.Module):
 class BaiChuanDecoderLayer(nn.Module):
 
     def __init__(self,
-                 config: BaiChuanConfig,
+                 config: PretrainedConfig,
                  position_embedding: str,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
@@ -245,7 +245,7 @@ class BaiChuanDecoderLayer(nn.Module):
 class BaiChuanModel(nn.Module):
 
     def __init__(self,
-                 config: BaiChuanConfig,
+                 config: PretrainedConfig,
                  position_embedding: str,
                  linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
vllm/model_executor/models/decilm.py

@@ -41,7 +41,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
     Based on the llama executor.
 
     The main difference is that DeciLM uses Variable Grouped Query Attention.
-    The constant number of GQA heads in the decoder is overriden with a value
+    The constant number of GQA heads in the decoder is overridden with a value
     per layer.
 
     Usually, in the HuggingFace implementation, instead of
vllm/model_executor/models/gemma.py

@@ -20,10 +20,13 @@ import torch
 from torch import nn
 from transformers import GemmaConfig
 
+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearMethodBase,
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -40,21 +43,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class GemmaRMSNorm(nn.Module):
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
-
-
 class GemmaMLP(nn.Module):
 
     def __init__(
@@ -64,27 +52,21 @@ class GemmaMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(hidden_size,
-                                              intermediate_size,
-                                              bias=False,
-                                              linear_method=linear_method)
-        self.up_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size,
-                                            bias=False,
-                                            linear_method=linear_method)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            linear_method=linear_method)
-        self.act_fn = nn.GELU()
+        self.act_fn = GeluAndMul()
 
     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        gate = self.act_fn(gate)
-        up, _ = self.up_proj(x)
-        fuse = gate * up
-        outputs, _ = self.down_proj(fuse)
-        return outputs
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
 
 
 class GemmaAttention(nn.Module):
@@ -185,10 +167,10 @@ class GemmaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -196,25 +178,27 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class GemmaModel(nn.Module):
@@ -235,7 +219,7 @@ class GemmaModel(nn.Module):
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -246,27 +230,53 @@ class GemmaModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
 
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
 class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
         config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
+        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.linear_method = linear_method
@@ -304,6 +314,8 @@ class GemmaForCausalLM(nn.Module):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
@@ -318,9 +330,10 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra layer for lora models.
-                if "lm_head" in name:
-                    continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -329,5 +342,5 @@ class GemmaForCausalLM(nn.Module):
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")
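The weight-loading hunk above folds Gemma's "(1 + weight)" RMSNorm scaling into the checkpoint weights so the stock RMSNorm kernel can be reused. A minimal standalone sketch of that equivalence; the function names here are illustrative, not vllm APIs, and the removed GemmaRMSNorm also upcasts to float32 before normalizing, which is skipped for brevity:

import torch


def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Mirrors the removed GemmaRMSNorm: scale the normalized input by (1 + weight).
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return normed * (1 + weight)


def plain_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Standard RMSNorm: scale by weight only.
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return normed * weight


x = torch.randn(2, 8)
w = torch.randn(8)  # the "zero-centered" weight stored in the Gemma checkpoint
# Adding 1.0 at load time (as load_weights now does) makes the plain kernel match.
assert torch.allclose(gemma_rms_norm(x, w), plain_rms_norm(x, w + 1.0))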
vllm/model_executor/models/llama.py

@@ -27,6 +27,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
@@ -45,7 +46,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
         bias: bool = False,
+        sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -141,7 +142,8 @@ class LlamaAttention(nn.Module):
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    self.scaling,
-                                   num_kv_heads=self.num_kv_heads)
+                                   num_kv_heads=self.num_kv_heads,
+                                   sliding_window=sliding_window)
 
     def forward(
         self,
@@ -172,6 +174,7 @@ class LlamaDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
+        sliding_window = getattr(config, "sliding_window", None)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -182,6 +185,7 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
             bias=getattr(config, "bias", False),
+            sliding_window=sliding_window,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
vllm/model_executor/models/neuron/llama.py (new file, mode 100644)

"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaForCausalLM(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        linear_method=None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = None
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        with torch.inference_mode():
            block_size = self.model.context_buckets[-1]
            if input_metadata.is_prompt:
                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
            else:
                seq_ids = input_metadata.block_tables
            logits = self.model(input_ids,
                                cache_ids=positions,
                                start_ids=seq_ids.flatten())
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None,
                     **kwargs):
        from transformers_neuronx.llama.model import LlamaForSampling

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            split_model_dir = model_name_or_path
        elif not os.path.exists(f"{model_name_or_path}-split"):
            from transformers.models.llama import LlamaForCausalLM
            from transformers_neuronx.module import save_pretrained_split

            hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
                                                        low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, f"{model_name_or_path}-split")

        self.model = LlamaForSampling.from_pretrained(split_model_dir,
                                                      **kwargs)
        self.model.to_neuron()
vllm/model_executor/models/olmo.py

@@ -61,7 +61,9 @@ from vllm.model_executor.weight_utils import (
     hf_model_weights_iterator,
 )
 from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.olmo import OLMoConfig
+
+# this model must need this dependency
+from hf_olmo import OLMoConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
vllm/model_executor/models/mistral.py → vllm/model_executor/models/orion.py

 # coding=utf-8
 # Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Inference-only Mistral model compatible with HuggingFace weights."""
-from typing import List, Optional, Tuple
+# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
+# Copyright (c) OrionStar Inc.
+# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
+"""Inference-only Orion-14B model compatible with HuggingFace weights."""
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
-from transformers import MistralConfig
+from transformers import PretrainedConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -38,19 +20,18 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
+    VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class MistralMLP(nn.Module):
+class OrionMLP(nn.Module):
 
     def __init__(
         self,
@@ -80,16 +61,18 @@ class MistralMLP(nn.Module):
         return x
 
 
-class MistralAttention(nn.Module):
+class OrionAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 linear_method: Optional[LinearMethodBase] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -111,7 +94,7 @@ class MistralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
+        self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -131,14 +114,14 @@ class MistralAttention(nn.Module):
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position=max_position,
-            base=self.rope_theta,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
         )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   sliding_window=self.sliding_window)
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -156,35 +139,39 @@ class MistralAttention(nn.Module):
         return output
 
 
-class MistralDecoderLayer(nn.Module):
+class OrionDecoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: MistralConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
-        self.self_attn = MistralAttention(
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.self_attn = OrionAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
-            sliding_window=config.sliding_window)
-        self.mlp = MistralMLP(
+        )
+        self.mlp = OrionMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             linear_method=linear_method,
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -195,12 +182,8 @@ class MistralDecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -208,39 +191,36 @@ class MistralDecoderLayer(nn.Module):
             input_metadata=input_metadata,
         )
+        hidden_states = residual + hidden_states
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        return hidden_states, residual
+        hidden_states = residual + hidden_states
+        return hidden_states, None
 
 
-class MistralModel(nn.Module):
+class OrionModel(nn.Module):
 
     def __init__(
         self,
-        config: MistralConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
-        self.vocab_size = config.vocab_size + lora_vocab
-        self.org_vocab_size = config.vocab_size
+        self.vocab_size = config.vocab_size
+
         self.embed_tokens = VocabParallelEmbedding(
-            self.vocab_size,
+            config.vocab_size,
             config.hidden_size,
-            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            MistralDecoderLayer(config, linear_method)
+            OrionDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -260,63 +240,23 @@ class MistralModel(nn.Module):
                 input_metadata,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states = self.norm(hidden_states)
         return hidden_states
 
 
-class MistralForCausalLM(nn.Module):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "gate_up_proj",
-        "down_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
+class OrionForCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: MistralConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = MistralModel(config,
-                                  linear_method,
-                                  lora_config=lora_config)
-        unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
-            # We need bigger padding if using lora for kernel
-            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
-        )
-        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
+        self.model = OrionModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.sampler = Sampler(config.vocab_size)
 
     def forward(
         self,
@@ -356,6 +296,11 @@ class MistralForCausalLM(nn.Module):
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
vllm/model_executor/models/qwen.py

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from torch import nn
+from transformers import PretrainedConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -27,7 +28,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.transformers_utils.configs.qwen import QWenConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -127,7 +127,7 @@ class QWenBlock(nn.Module):
     def __init__(
         self,
-        config: QWenConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -179,7 +179,7 @@ class QWenModel(nn.Module):
     def __init__(
         self,
-        config: QWenConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -222,7 +222,7 @@ class QWenLMHeadModel(nn.Module):
     def __init__(
         self,
-        config: QWenConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
vllm/model_executor/models/stablelm.py

@@ -94,7 +94,9 @@ class StablelmAttention(nn.Module):
             1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
+        rope_pct = getattr(config, "rope_pct",
+                           getattr(config, "partial_rotary_factor", 1))
+        self.rotary_ndims = int(self.head_dim * rope_pct)
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
@@ -114,7 +116,6 @@ class StablelmAttention(nn.Module):
                                         self.hidden_size,
                                         bias=False,
                                         linear_method=linear_method)
-        self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
@@ -152,10 +153,11 @@ class StablelmDecoderLayer(nn.Module):
         super().__init__()
         self.self_attn = StablelmAttention(config)
         self.mlp = StablelmMLP(config, linear_method)
-        self.input_layernorm = nn.LayerNorm(config.hidden_size,
-                                            eps=config.norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
-                                                     eps=config.norm_eps)
+                                                     eps=norm_eps)
 
     def forward(
         self,
@@ -199,7 +201,9 @@ class StableLMEpochModel(nn.Module):
             StablelmDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
 
     def forward(
         self,
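The getattr chains introduced above let the same module read either the older StableLM config field names (rope_pct, norm_eps) or the names used by newer Hugging Face StableLM configs (partial_rotary_factor, layer_norm_eps), defaulting to 1 and 1e-05 when neither exists. A small standalone sketch of how the fallback resolves; the two config objects are illustrative stand-ins, not real checkpoints:

from types import SimpleNamespace

old_cfg = SimpleNamespace(rope_pct=0.25, norm_eps=1e-5)
new_cfg = SimpleNamespace(partial_rotary_factor=0.25, layer_norm_eps=1e-5)

for cfg in (old_cfg, new_cfg):
    # Same expressions as in the diff: try the old field, then the new one.
    rope_pct = getattr(cfg, "rope_pct",
                       getattr(cfg, "partial_rotary_factor", 1))
    norm_eps = getattr(cfg, "norm_eps",
                       getattr(cfg, "layer_norm_eps", 1e-05))
    print(rope_pct, norm_eps)  # both configs resolve to 0.25 and 1e-05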
vllm/model_executor/models/starcoder2.py (new file, mode 100644)

# coding=utf-8
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

try:
    from transformers import Starcoder2Config
except ImportError:
    # fallback to PretrainedConfig
    # NOTE: Please install transformers from source or use transformers>=4.39.0
    from transformers import PretrainedConfig as Starcoder2Config

KVCache = Tuple[torch.Tensor, torch.Tensor]


class Starcoder2Attention(nn.Module):

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias
        self.sliding_window = config.sliding_window

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.use_bias,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.use_bias,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Starcoder2MLP(nn.Module):

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.use_bias,
            linear_method=linear_method,
        )
        self.act = get_act_fn(config.hidden_act,
                              intermediate_size=config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class Starcoder2DecoderLayer(nn.Module):

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(config,
                                             linear_method=linear_method)
        self.mlp = Starcoder2MLP(config, linear_method=linear_method)
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class Starcoder2Model(nn.Module):

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # TODO: consider padding_idx (currently removed)
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            Starcoder2DecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states, kv_caches[i],
                                  input_metadata)
        hidden_states = self.norm(hidden_states)
        return hidden_states


class Starcoder2ForCausalLM(nn.Module):

    def __init__(self,
                 config: Starcoder2Config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.model = Starcoder2Model(config, linear_method=linear_method)
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            self.lm_head_weight = self.model.embed_tokens.weight
        else:
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            )
            self.lm_head_weight = self.lm_head.weight
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if self.config.tie_word_embeddings and "lm_head.weight" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
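Together with the config-registry change in vllm/transformers_utils/config.py further down, this new file is what allows a StarCoder2 checkpoint to be served through vllm's usual entry point. A hedged usage sketch: the model name is only an example, and behavior on transformers releases older than 4.39 depends on the PretrainedConfig fallback in the try/except at the top of the file.

# Hypothetical usage sketch; assumes this fork is installed and the
# example checkpoint "bigcode/starcoder2-3b" is accessible locally or via the Hub.
from vllm import LLM, SamplingParams

llm = LLM(model="bigcode/starcoder2-3b")
outputs = llm.generate(["def fibonacci(n):"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)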
vllm/model_executor/neuron_model_loader.py (new file, mode 100644)

"""Utilities for selecting and loading models."""
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import ModelConfig, DeviceConfig
from vllm.model_executor.models import ModelRegistry

TORCH_DTYPE_TO_NEURON_AMP = {
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def get_model(model_config: ModelConfig, device_config: DeviceConfig,
              **kwargs) -> nn.Module:
    from transformers_neuronx.config import (NeuronConfig,
                                             ContinuousBatchingConfig)

    parallel_config = kwargs.get("parallel_config")
    scheduler_config = kwargs.get("scheduler_config")

    model_class = _get_model_architecture(model_config.hf_config)
    linear_method = None

    # Create a model instance.
    model = model_class(model_config.hf_config, linear_method)

    continuous_batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    neuron_config = NeuronConfig(
        continuous_batching=continuous_batching_config)

    # Load the weights from the cached or downloaded files.
    model.load_weights(
        model_config.model,
        model_config.download_dir,
        model_config.load_format,
        model_config.revision,
        tp_degree=parallel_config.neuron_tp_degree,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=neuron_config,
        context_length_estimate=[scheduler_config.max_model_len],
        n_positions=[scheduler_config.max_model_len],
        batch_size=scheduler_config.max_num_seqs)

    return model.eval()
vllm/model_executor/parallel_utils/custom_all_reduce.py

@@ -36,14 +36,14 @@ def init_custom_ar() -> None:
     if world_size not in _SUPPORTED_WORLD_SIZES:
         logger.warn(
             "Custom allreduce is disabled due to an unsupported world size: "
-            "%d. Supported world sizes: %s. To slience this warning, specify"
+            "%d. Supported world sizes: %s. To silence this warning, specify"
             "disable_custom_all_reduce=True explicitly.", world_size,
             str(_SUPPORTED_WORLD_SIZES))
         return
     if not _can_p2p(rank, world_size):
         logger.warn(
             "Custom allreduce is disabled because your platform lacks GPU P2P"
-            " capability. To slience this warning, specify"
+            " capability. To silence this warning, specify"
             "disable_custom_all_reduce=True explicitly.")
         return
     _CA_HANDLE = CustomAllreduce(rank, world_size)
vllm/model_executor/parallel_utils/parallel_state.py

@@ -189,7 +189,7 @@ def get_pipeline_model_parallel_next_rank():
 
 def get_pipeline_model_parallel_prev_rank():
-    """Return the global rank that preceeds the caller in the pipeline"""
+    """Return the global rank that precedes the caller in the pipeline"""
     assert _PIPELINE_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
     rank_in_pipeline = get_pipeline_model_parallel_rank()
vllm/model_executor/sampling_metadata.py

@@ -5,7 +5,7 @@ import torch
 
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, is_neuron
 
 _SAMPLING_EPS = 1e-5
@@ -155,7 +155,7 @@ class SamplingTensors:
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        pin_memory = not in_wsl()
+        pin_memory = not in_wsl() and not is_neuron()
         prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
         prompt_padded_tokens = [
             tokens + [vocab_size] * (prompt_max_len - len(tokens))
vllm/model_executor/utils.py  View file @ a130cf33
 """Utils for model executor."""
 import random
+import importlib
 from typing import Any, Dict, Optional

 import numpy as np
 import torch

+from vllm.config import DeviceConfig, ModelConfig
+
+DEVICE_TO_MODEL_LOADER_MAP = {
+    "cuda": "model_loader",
+    "neuron": "neuron_model_loader",
+}
+

 def set_random_seed(seed: int) -> None:
     random.seed(seed)
...
@@ -33,3 +41,12 @@ def set_weight_attrs(
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
         setattr(weight, key, value)
+
+
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> torch.nn.Module:
+    model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
+    imported_model_loader = importlib.import_module(
+        f"vllm.model_executor.{model_loader_module}")
+    get_model_fn = imported_model_loader.get_model
+    return get_model_fn(model_config, device_config, **kwargs)
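The new get_model is a device-keyed dispatcher: it looks the device type up in DEVICE_TO_MODEL_LOADER_MAP, imports the matching vllm.model_executor loader module, and delegates to that module's own get_model. A stripped-down sketch of the same importlib pattern, with a hypothetical extra backend to show how it extends (the "cpu" entry is invented for illustration and not part of this commit):

    import importlib

    LOADER_MAP = {
        "cuda": "model_loader",
        "neuron": "neuron_model_loader",
        # "cpu": "cpu_model_loader",  # hypothetical backend, not in this commit
    }

    def dispatch_get_model(device_type: str, *args, **kwargs):
        # Each loader module is expected to expose its own get_model() entry point.
        loader = importlib.import_module(
            f"vllm.model_executor.{LOADER_MAP[device_type]}")
        return loader.get_model(*args, **kwargs)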
vllm/sampling_params.py  View file @ a130cf33
 """Sampling parameters for text generation."""
+import copy
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union
...
@@ -237,6 +238,20 @@ class SamplingParams:
             return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM

+    def clone(self) -> "SamplingParams":
+        """Deep copy excluding LogitsProcessor objects.
+
+        LogitsProcessor objects are excluded because they may contain an
+        arbitrary, nontrivial amount of data.
+        See https://github.com/vllm-project/vllm/issues/3087
+        """
+        logit_processor_refs = None if self.logits_processors is None else {
+            id(lp): lp
+            for lp in self.logits_processors
+        }
+        return copy.deepcopy(self, memo=logit_processor_refs)
+
     def __repr__(self) -> str:
         return (
             f"SamplingParams(n={self.n}, "
...
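clone() relies on deepcopy's memo argument: pre-seeding it with {id(lp): lp} makes every logits processor pass through by reference while the rest of the object is duplicated. A small illustration, assuming a stand-in processor function that is not part of this diff:

    from vllm import SamplingParams

    def passthrough(token_ids, logits):  # stand-in logits processor
        return logits

    params = SamplingParams(temperature=0.8, logits_processors=[passthrough])
    copied = params.clone()
    assert copied is not params                                        # params object is duplicated
    assert copied.logits_processors[0] is params.logits_processors[0]  # processor is shared by reference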
vllm/transformers_utils/config.py  View file @ a130cf33
...
@@ -5,12 +5,11 @@ from transformers import AutoConfig, PretrainedConfig
 from vllm.transformers_utils.configs import *

 _CONFIG_REGISTRY = {
-    "baichuan": BaiChuanConfig,
     "chatglm": ChatGLMConfig,
     "mpt": MPTConfig,
-    "qwen": QWenConfig,
     "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
     "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
+    "starcoder2": Starcoder2Config,
 }
...
@@ -18,6 +17,15 @@ def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
                code_revision: Optional[str] = None) -> PretrainedConfig:
+    # FIXME(woosuk): This is a temporary fix for StarCoder2.
+    # Remove this when the model is supported by HuggingFace transformers.
+    if "bigcode" in model and "starcoder2" in model:
+        config_class = _CONFIG_REGISTRY["starcoder2"]
+        config = config_class.from_pretrained(model,
+                                              revision=revision,
+                                              code_revision=code_revision)
+        return config
     try:
         config = AutoConfig.from_pretrained(
             model,
...
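With that bypass in place, get_config resolves StarCoder2 checkpoints from the local registry before AutoConfig is ever consulted. A hedged usage sketch (the checkpoint name is illustrative):

    from vllm.transformers_utils.config import get_config

    # Both "bigcode" and "starcoder2" appear in the name, so the temporary
    # registry path above is taken instead of AutoConfig.from_pretrained.
    config = get_config("bigcode/starcoder2-3b", trust_remote_code=False)
    print(type(config).__name__)  # expected: Starcoder2Config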
vllm/transformers_utils/configs/__init__.py  View file @ a130cf33
-from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
-from vllm.transformers_utils.configs.olmo import OLMoConfig
-from vllm.transformers_utils.configs.qwen import QWenConfig
 # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config

 __all__ = [
-    "BaiChuanConfig",
     "ChatGLMConfig",
     "MPTConfig",
-    "OLMoConfig",
-    "QWenConfig",
     "RWConfig",
+    "Starcoder2Config",
 ]
vllm/transformers_utils/configs/baichuan.py  deleted 100644 → 0  View file @ a2d181be
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig


class BaiChuanConfig(PretrainedConfig):
    model_type = "baichuan"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=64000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
vllm/transformers_utils/configs/olmo.py  deleted 100644 → 0  View file @ a2d181be
# coding=utf-8
# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py
"""OLMo configuration"""
from transformers import PretrainedConfig


class OLMoConfig(PretrainedConfig):
    model_type = 'olmo'
    attribute_map = {
        'num_attention_heads': 'n_heads',
        'hidden_size': 'd_model',
        'num_hidden_layers': 'n_layers',
    }

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
    def __init__(
        self,
        d_model=768,
        n_heads=12,
        n_layers=12,
        mlp_ratio=4,
        mlp_hidden_size=None,
        activation_type="swiglu",
        block_type="sequential",
        block_group_size=1,
        alibi=False,
        alibi_bias_max=8.0,
        rope=False,
        rope_full_precision=True,
        multi_query_attention=False,
        attention_layer_norm=False,
        layer_norm_type="default",
        layer_norm_with_affine=True,
        attention_layer_norm_with_affine=True,
        max_sequence_length=1024,
        include_bias=True,
        bias_for_layer_norm=None,
        scale_logits=False,
        vocab_size=50257,
        embedding_size=50304,
        weight_tying=True,
        eos_token_id=50256,
        pad_token_id=50256,
        **kwargs,
    ):
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_size = mlp_hidden_size
        self.activation_type = activation_type
        self.block_type = block_type
        self.block_group_size = block_group_size
        self.alibi = alibi
        self.alibi_bias_max = alibi_bias_max
        self.rope = rope
        self.rope_full_precision = rope_full_precision
        self.multi_query_attention = multi_query_attention
        self.attention_layer_norm = attention_layer_norm
        self.layer_norm_type = layer_norm_type
        self.layer_norm_with_affine = layer_norm_with_affine
        self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
        self.max_sequence_length = max_sequence_length
        self.include_bias = include_bias
        self.bias_for_layer_norm = bias_for_layer_norm
        self.scale_logits = scale_logits
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.weight_tying = weight_tying
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        super().__init__(**kwargs)
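For reference, the attribute_map in this now-removed class is what let callers keep using the generic HuggingFace attribute names: PretrainedConfig forwards reads of hidden_size, num_attention_heads, and num_hidden_layers to the OLMo-native fields. A quick sketch of that aliasing, using the class above (sizes are illustrative):

    cfg = OLMoConfig(d_model=2048, n_heads=16, n_layers=24)
    assert cfg.hidden_size == cfg.d_model == 2048       # forwarded via attribute_map
    assert cfg.num_attention_heads == cfg.n_heads == 16
    assert cfg.num_hidden_layers == cfg.n_layers == 24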