Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
db674e3d
"vscode:/vscode.git/clone" did not exist on "e557ed89ff0aefeb2a5ffbc601b39ee88b16f1e3"
Unverified
Commit
db674e3d
authored
Nov 28, 2024
by
Jani Monoses
Committed by
GitHub
Nov 28, 2024
Browse files
Add OLMo2 model. (#2233)
parent
fb915bd1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
393 additions
and
0 deletions
+393
-0
python/sglang/srt/models/olmo2.py
python/sglang/srt/models/olmo2.py
+392
-0
test/srt/models/test_generation_models.py
test/srt/models/test_generation_models.py
+1
-0
No files found.
python/sglang/srt/models/olmo2.py
0 → 100755
View file @
db674e3d
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/olmo2.py
"""Inference-only OLMo2 model compatible with HuggingFace weights."""
from
functools
import
partial
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.utils
import
make_layers
class
Olmo2Attention
(
nn
.
Module
):
"""
This is the attention block where the output is computed as
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
config
:
PretrainedConfig
,
layer_id
:
int
=
0
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
config
.
num_attention_heads
assert
self
.
hidden_size
%
self
.
total_num_heads
==
0
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
self
.
config
.
num_key_value_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
max_position_embeddings
=
config
.
max_position_embeddings
self
.
rope_theta
=
config
.
rope_theta
# Attention input projection. Projects x -> (q, k, v)
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
config
.
attention_bias
,
)
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
k_norm
=
RMSNorm
(
self
.
total_num_kv_heads
*
self
.
head_dim
,
eps
=
self
.
config
.
rms_norm_eps
,
)
self
.
q_norm
=
RMSNorm
(
self
.
config
.
hidden_size
,
eps
=
self
.
config
.
rms_norm_eps
)
# Rotary embeddings.
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
base
=
self
.
rope_theta
,
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
RadixAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
layer_id
=
layer_id
,
)
# Attention output projection.
self
.
o_proj
=
RowParallelLinear
(
self
.
head_dim
*
self
.
total_num_heads
,
self
.
hidden_size
,
bias
=
config
.
attention_bias
,
)
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
self
.
tp_size
>
1
:
q
=
tensor_model_parallel_all_gather
(
q
.
contiguous
())
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
())
q
=
self
.
q_norm
.
forward_native
(
q
)
k
=
self
.
k_norm
.
forward_native
(
k
)
if
self
.
tp_size
>
1
:
splitter
=
partial
(
split_tensor_along_last_dim
,
num_partitions
=
self
.
tp_size
)
q
=
splitter
(
q
)[
self
.
tp_rank
]
k
=
splitter
(
k
)[
self
.
tp_rank
]
return
q
,
k
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
forward_batch
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
Olmo2MLP
(
nn
.
Module
):
"""
This is the MLP block where the output is computed as
``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
# Feed-forward input projection.
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
)
# Activation function.
self
.
act_fn
=
SiluAndMul
()
# Feed-forward output projection.
self
.
down_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
Olmo2DecoderLayer
(
nn
.
Module
):
"""
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
config
:
PretrainedConfig
,
layer_id
:
int
=
0
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
# Attention block.
self
.
self_attn
=
Olmo2Attention
(
config
,
layer_id
,
quant_config
)
# MLP block.
self
.
mlp
=
Olmo2MLP
(
config
,
quant_config
)
# RMSNorm
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_feedforward_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
)
->
torch
.
Tensor
:
# Attention block.
residual
=
hidden_states
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
,
forward_batch
)
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
hidden_states
+
residual
# MLP block.
residual
=
hidden_states
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
post_feedforward_layernorm
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
class
Olmo2Model
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
idx
,
prefix
:
Olmo2DecoderLayer
(
layer_id
=
idx
,
config
=
config
,
quant_config
=
quant_config
,
),
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
if
input_embeds
is
None
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
else
:
hidden_states
=
input_embeds
# Apply blocks one-by-one.
for
layer_id
,
decoder_layer
in
enumerate
(
self
.
layers
):
# shape: (batch_size, seq_len, d_model)
hidden_states
=
decoder_layer
(
positions
,
hidden_states
,
forward_batch
,
)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
Olmo2ForCausalLM
(
nn
.
Module
):
"""
Extremely barebones HF model wrapper.
"""
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
model
=
Olmo2Model
(
config
,
quant_config
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
quant_config
,
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
forward_batch
=
forward_batch
,
input_embeds
=
input_embeds
,
)
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
forward_batch
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
EntryClass
=
Olmo2ForCausalLM
test/srt/models/test_generation_models.py
View file @
db674e3d
...
@@ -56,6 +56,7 @@ ALL_OTHER_MODELS = [
...
@@ -56,6 +56,7 @@ ALL_OTHER_MODELS = [
ModelCase
(
"THUDM/glm-4-9b-chat"
),
ModelCase
(
"THUDM/glm-4-9b-chat"
),
ModelCase
(
"openai-community/gpt2"
),
ModelCase
(
"openai-community/gpt2"
),
ModelCase
(
"microsoft/Phi-3-small-8k-instruct"
),
ModelCase
(
"microsoft/Phi-3-small-8k-instruct"
),
ModelCase
(
"allenai/OLMo-2-1124-7B-Instruct"
,
skip_long_prompt
=
True
),
]
]
TORCH_DTYPES
=
[
torch
.
float16
]
TORCH_DTYPES
=
[
torch
.
float16
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment