"vscode:/vscode.git/clone" did not exist on "2ce72b9c22e92492cf4407b1d5ea8e01411c6040"
Unverified Commit c6d80a7a authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Model] Improve olmo and olmo2 (#23228)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 7cd17e22
......@@ -384,8 +384,8 @@ th {
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
......
......@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -91,6 +91,7 @@ class OlmoAttention(nn.Module):
self.total_num_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
# Rotary embeddings.
......@@ -114,6 +115,7 @@ class OlmoAttention(nn.Module):
self.hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
def forward(
......@@ -142,6 +144,7 @@ class OlmoMLP(nn.Module):
self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -154,6 +157,7 @@ class OlmoMLP(nn.Module):
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
# Activation function.
......@@ -165,6 +169,7 @@ class OlmoMLP(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
def forward(
......@@ -197,7 +202,7 @@ class OlmoDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn")
# MLP block.
self.mlp = OlmoMLP(config, quant_config)
self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")
# LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size,
......@@ -326,10 +331,21 @@ class OlmoModel(nn.Module):
return loaded_params
class OlmoForCausalLM(nn.Module, SupportsPP):
class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
"""
Extremely barebones HF model wrapper.
"""
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -33,6 +33,7 @@ from torch import nn
from transformers import Olmo2Config
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
......@@ -48,7 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
......@@ -253,6 +254,7 @@ class Olmo2DecoderLayer(nn.Module):
return hidden_states
@support_torch_compile
class Olmo2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......@@ -354,10 +356,21 @@ class Olmo2Model(nn.Module):
return loaded_params
class Olmo2ForCausalLM(nn.Module, SupportsPP):
class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
"""
Extremely barebones HF model wrapper.
"""
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment