sglang / Commits

Commit a5114b6f (unverified)
Authored Oct 16, 2024 by Jani Monoses, committed by GitHub on Oct 16, 2024
Parent: b6b40946

Add OLMo model (#1676)

Showing 2 changed files with 357 additions and 1 deletion:

  python/sglang/srt/models/olmo.py              +352  -0
  test/srt/models/test_generation_models.py       +5  -1
python/sglang/srt/models/olmo.py (new file, mode 100755)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1
"""Inference-only OLMo model compatible with HuggingFace weights."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
OlmoConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.layers.activation
import
SiluAndMul
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
class
OlmoAttention
(
nn
.
Module
):
"""
This is the attention block where the output is computed as
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
config
:
OlmoConfig
,
layer_id
:
int
=
0
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
tensor_model_parallel_world_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
config
.
num_attention_heads
assert
self
.
hidden_size
%
self
.
total_num_heads
==
0
assert
self
.
total_num_heads
%
tensor_model_parallel_world_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tensor_model_parallel_world_size
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
max_position_embeddings
=
config
.
max_position_embeddings
self
.
rope_theta
=
config
.
rope_theta
self
.
clip_qkv
=
config
.
clip_qkv
# Attention input projection. Projects x -> (q, k, v)
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
config
.
attention_bias
,
)
# Rotary embeddings.
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
base
=
self
.
rope_theta
,
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
RadixAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_heads
,
layer_id
=
layer_id
,
)
# Attention output projection.
self
.
o_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
config
.
attention_bias
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
if
self
.
clip_qkv
is
not
None
:
qkv
.
clamp_
(
min
=-
self
.
clip_qkv
,
max
=
self
.
clip_qkv
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
forward_batch
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
OlmoMLP
(
nn
.
Module
):
"""
This is the MLP block where the output is computed as
``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
config
:
OlmoConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
# Feed-forward input projection.
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
)
# Activation function.
self
.
act_fn
=
SiluAndMul
()
# Feed-forward output projection.
self
.
down_proj
=
RowParallelLinear
(
self
.
intermediate_size
,
self
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
OlmoDecoderLayer
(
nn
.
Module
):
"""
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
config
:
OlmoConfig
,
layer_id
:
int
=
0
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
# Attention block.
self
.
self_attn
=
OlmoAttention
(
config
,
layer_id
,
quant_config
)
# MLP block.
self
.
mlp
=
OlmoMLP
(
config
,
quant_config
)
# LayerNorm
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
elementwise_affine
=
False
,
bias
=
False
)
self
.
post_attention_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
elementwise_affine
=
False
,
bias
=
False
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Attention block.
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
,
forward_batch
)
hidden_states
=
hidden_states
+
residual
# MLP block.
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
class
OlmoModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
OlmoConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
(
[
OlmoDecoderLayer
(
config
,
layer_idx
,
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
]
)
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
elementwise_affine
=
False
,
bias
=
False
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
if
input_embeds
is
None
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
else
:
hidden_states
=
input_embeds
# Apply blocks one-by-one.
for
layer_idx
,
decoder_layer
in
enumerate
(
self
.
layers
):
# shape: (batch_size, seq_len, d_model)
hidden_states
=
decoder_layer
(
positions
,
hidden_states
,
forward_batch
,
)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
OlmoForCausalLM
(
nn
.
Module
):
"""
Extremely barebones HF model wrapper.
"""
def
__init__
(
self
,
config
:
OlmoConfig
,
cache_config
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
model
=
OlmoModel
(
config
,
quant_config
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
quant_config
,
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
forward_batch
=
forward_batch
,
input_embeds
=
input_embeds
,
)
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
forward_batch
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
EntryClass
=
OlmoForCausalLM
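The trailing `EntryClass = OlmoForCausalLM` is what lets sglang's model loader pick up this implementation for checkpoints whose architecture is `OlmoForCausalLM`. For orientation, the sizing logic in `OlmoAttention.__init__` can be reproduced directly from the HuggingFace config of the checkpoint the test below exercises. The snippet is a minimal sketch, not part of this commit; it only assumes a `transformers` version that ships `OlmoConfig` and access to the Hub, and `tp_size` is a hypothetical tensor-parallel world size.

# Sketch (not part of the commit): derive the per-rank attention shapes the
# same way OlmoAttention.__init__ does, from the HF config of the test model.
from transformers import OlmoConfig

config = OlmoConfig.from_pretrained("allenai/OLMo-1B-0724-hf")
tp_size = 1  # hypothetical tensor-parallel world size

total_num_heads = config.num_attention_heads
assert config.hidden_size % total_num_heads == 0
assert total_num_heads % tp_size == 0

num_heads = total_num_heads // tp_size            # attention heads per rank
head_dim = config.hidden_size // total_num_heads  # per-head dimension
scaling = head_dim**-0.5                          # softmax scaling factor

print(num_heads, head_dim, scaling, config.clip_qkv, config.tie_word_embeddings)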
test/srt/models/test_generation_models.py (mode changed 100644 → 100755)
...
@@ -53,6 +53,7 @@ ALL_OTHER_MODELS = [
     ModelCase("Qwen/Qwen2-1.5B"),
     ModelCase("Qwen/Qwen2.5-14B-Instruct"),
     ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"),
+    ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2),
 ]
 TORCH_DTYPES = [torch.float16]
...
@@ -153,7 +154,10 @@ class TestGenerationModels(unittest.TestCase):
         # Skip long prompts for models that does not have a long context
         prompts = DEFAULT_PROMPTS
-        if model_case.model_path in ["HuggingFaceTB/SmolLM-135M-Instruct"]:
+        if model_case.model_path in [
+            "HuggingFaceTB/SmolLM-135M-Instruct",
+            "allenai/OLMo-1B-0724-hf",
+        ]:
             prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]

         # Assert the logits and output strs are close
...
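The `decode_tolerance=8e-2` passed for the OLMo entry loosens how closely sglang's decode outputs must track the HuggingFace reference for this checkpoint. The harness internals are not shown in this hunk, so the following is only a sketch of the kind of check such a tolerance typically gates; the names `srt_logprobs` and `hf_logprobs` are hypothetical and do not come from the test file.

# Sketch (assumption, not the harness code): a per-token decode comparison
# that a model-specific tolerance like decode_tolerance=8e-2 could gate.
import torch

def decode_logprobs_close(
    srt_logprobs: torch.Tensor,   # hypothetical: per-token logprobs from sglang
    hf_logprobs: torch.Tensor,    # hypothetical: per-token logprobs from HF
    decode_tolerance: float = 8e-2,
) -> bool:
    # Mean absolute difference must stay under the model-specific tolerance.
    return (srt_logprobs - hf_logprobs).abs().mean().item() < decode_tolerance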