Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b57abe16
Unverified
Commit
b57abe16
authored
Mar 22, 2024
by
Jani Monoses
Committed by
GitHub
Mar 22, 2024
Browse files
Add StableLM model. (#301)
parent
e57f0792
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
293 additions
and
0 deletions
+293
-0
python/sglang/srt/models/stablelm.py
python/sglang/srt/models/stablelm.py
+293
-0
No files found.
python/sglang/srt/models/stablelm.py
0 → 100644
View file @
b57abe16
# This code is based on:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
model compatible with HuggingFace weights."""
from
typing
import
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.managers.router.model_runner
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
,
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
,
)
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
,
)
class
StablelmMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
config
.
hidden_size
,
[
config
.
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
,
)
self
.
down_proj
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
False
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
StablelmAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
layer_id
:
int
=
0
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_key_value_heads
=
config
.
num_key_value_heads
if
self
.
total_num_key_value_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_key_value_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_key_value_heads
==
0
self
.
num_key_value_heads
=
max
(
1
,
self
.
total_num_key_value_heads
//
tp_size
)
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
max_position_embeddings
=
config
.
max_position_embeddings
rope_pct
=
getattr
(
config
,
"rope_pct"
,
getattr
(
config
,
"partial_rotary_factor"
,
1
)
)
self
.
rotary_ndims
=
int
(
self
.
head_dim
*
rope_pct
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_key_value_heads
*
self
.
head_dim
self
.
qkv_bias
=
getattr
(
config
,
"use_qkv_bias"
,
False
)
if
(
self
.
head_dim
*
self
.
num_heads
*
tp_size
)
!=
self
.
hidden_size
:
raise
ValueError
(
f
"hidden_size must be divisible by num_heads "
f
"(got `hidden_size`:
{
self
.
hidden_size
}
"
f
" and `num_heads`:
{
self
.
num_heads
}
)."
)
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_key_value_heads
,
self
.
qkv_bias
,
linear_method
=
linear_method
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
rotary_ndims
,
max_position
=
self
.
config
.
max_position_embeddings
,
base
=
self
.
config
.
rope_theta
,
)
self
.
attn
=
RadixAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_key_value_heads
,
layer_id
=
layer_id
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
input_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
StablelmDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
layer_id
:
int
=
0
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
self_attn
=
StablelmAttention
(
config
,
layer_id
=
layer_id
)
self
.
mlp
=
StablelmMLP
(
config
,
linear_method
)
norm_eps
=
getattr
(
config
,
"norm_eps"
,
getattr
(
config
,
"layer_norm_eps"
,
1e-05
))
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
norm_eps
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
input_metadata
=
input_metadata
,
)
hidden_states
=
residual
+
hidden_states
# Fully Connected
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
,
residual
class
StableLMEpochModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
(
[
StablelmDecoderLayer
(
config
,
i
,
linear_method
)
for
i
in
range
(
config
.
num_hidden_layers
)
]
)
norm_eps
=
getattr
(
config
,
"norm_eps"
,
getattr
(
config
,
"layer_norm_eps"
,
1e-05
))
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
if
input_embeds
is
None
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
else
:
hidden_states
=
input_embeds
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
input_metadata
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
StableLmForCausalLM
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
StableLMEpochModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
)
def
load_weights
(
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
,
):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"rotary_emb.inv_freq"
in
name
:
continue
if
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
EntryClass
=
StableLmForCausalLM
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment