Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c6d80a7a
Unverified
Commit
c6d80a7a
authored
Aug 20, 2025
by
Jee Jee Li
Committed by
GitHub
Aug 20, 2025
Browse files
[Model] Improve olmo and olmo2 (#23228)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent
7cd17e22
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
7 deletions
+36
-7
docs/models/supported_models.md
docs/models/supported_models.md
+2
-2
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+19
-3
vllm/model_executor/models/olmo2.py
vllm/model_executor/models/olmo2.py
+15
-2
No files found.
docs/models/supported_models.md
View file @
c6d80a7a
...
...
@@ -384,8 +384,8 @@ th {
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
...
...
vllm/model_executor/models/olmo.py
View file @
c6d80a7a
...
...
@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -91,6 +91,7 @@ class OlmoAttention(nn.Module):
self
.
total_num_heads
,
bias
=
config
.
attention_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
# Rotary embeddings.
...
...
@@ -114,6 +115,7 @@ class OlmoAttention(nn.Module):
self
.
hidden_size
,
bias
=
config
.
attention_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
def
forward
(
...
...
@@ -142,6 +144,7 @@ class OlmoMLP(nn.Module):
self
,
config
:
OlmoConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -154,6 +157,7 @@ class OlmoMLP(nn.Module):
[
self
.
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
# Activation function.
...
...
@@ -165,6 +169,7 @@ class OlmoMLP(nn.Module):
self
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
def
forward
(
...
...
@@ -197,7 +202,7 @@ class OlmoDecoderLayer(nn.Module):
prefix
=
f
"
{
prefix
}
.self_attn"
)
# MLP block.
self
.
mlp
=
OlmoMLP
(
config
,
quant_config
)
self
.
mlp
=
OlmoMLP
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
# LayerNorm
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
...
...
@@ -326,10 +331,21 @@ class OlmoModel(nn.Module):
return
loaded_params
class
OlmoForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
OlmoForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
):
"""
Extremely barebones HF model wrapper.
"""
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
vllm/model_executor/models/olmo2.py
View file @
c6d80a7a
...
...
@@ -33,6 +33,7 @@ from torch import nn
from
transformers
import
Olmo2Config
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_gather
...
...
@@ -48,7 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsPP
from
vllm.model_executor.models.interfaces
import
SupportsLoRA
,
SupportsPP
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -253,6 +254,7 @@ class Olmo2DecoderLayer(nn.Module):
return
hidden_states
@
support_torch_compile
class
Olmo2Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -354,10 +356,21 @@ class Olmo2Model(nn.Module):
return
loaded_params
class
Olmo2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
Olmo2ForCausalLM
(
nn
.
Module
,
SupportsPP
,
SupportsLoRA
):
"""
Extremely barebones HF model wrapper.
"""
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment