Unverified Commit d6bb2a9d authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix Plamo 2/3 & LFM2 for Transformers v5 (#38090)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 1e673a43
......@@ -52,7 +52,7 @@ class Lfm2MLP(nn.Module):
def __init__(
self,
dim: int,
ff_dim: int,
intermediate_size: int,
multiple_of: int,
auto_adjust_ff_dim: bool,
ffn_dim_multiplier: float | None,
......@@ -61,21 +61,23 @@ class Lfm2MLP(nn.Module):
):
super().__init__()
if auto_adjust_ff_dim:
ff_dim = int(2 * ff_dim / 3)
intermediate_size = int(2 * intermediate_size / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
ff_dim = int(ffn_dim_multiplier * ff_dim)
ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
intermediate_size = int(ffn_dim_multiplier * intermediate_size)
intermediate_size = multiple_of * (
(intermediate_size + multiple_of - 1) // multiple_of
)
self.w13 = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[ff_dim] * 2,
output_sizes=[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.w13",
)
self.w2 = RowParallelLinear(
input_size=ff_dim,
input_size=intermediate_size,
output_size=dim,
bias=False,
quant_config=quant_config,
......@@ -212,7 +214,7 @@ class Lfm2AttentionDecoderLayer(nn.Module):
self.feed_forward = Lfm2MLP(
dim=config.block_dim,
ff_dim=config.block_ff_dim,
intermediate_size=config.intermediate_size,
multiple_of=config.block_multiple_of,
auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
ffn_dim_multiplier=config.block_ffn_dim_multiplier,
......@@ -262,7 +264,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
self.feed_forward = Lfm2MLP(
dim=config.block_dim,
ff_dim=config.block_ff_dim,
intermediate_size=config.intermediate_size,
multiple_of=config.block_multiple_of,
auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
ffn_dim_multiplier=config.block_ffn_dim_multiplier,
......
......@@ -4,6 +4,7 @@
from collections.abc import Iterable
from itertools import islice
from typing import TYPE_CHECKING
import torch
from torch import nn
......@@ -71,9 +72,10 @@ from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Only used for type hinting.
class Plamo2Config(PretrainedConfig): # type: ignore
if TYPE_CHECKING:
class Plamo2Config(PretrainedConfig): # type: ignore
model_type: str = "plamo2"
hidden_size: int
......@@ -94,7 +96,7 @@ class Plamo2Config(PretrainedConfig): # type: ignore
vocab_size: int
def is_mamba(config: Plamo2Config, i: int) -> bool:
def is_mamba(config: "Plamo2Config", i: int) -> bool:
assert config.mamba_step > 1
if config.num_hidden_layers <= (config.mamba_step // 2):
......@@ -502,7 +504,7 @@ direct_register_custom_op(
class DenseMLP(nn.Module):
def __init__(
self,
config: Plamo2Config,
config: "Plamo2Config",
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
......
......@@ -4,7 +4,7 @@
from collections.abc import Iterable
from itertools import islice
from typing import Any
from typing import TYPE_CHECKING, Any
import torch
from torch import nn
......@@ -46,9 +46,10 @@ from vllm.model_executor.models.utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
# Only used for type hinting.
class Plamo3Config(PretrainedConfig): # type: ignore
if TYPE_CHECKING:
class Plamo3Config(PretrainedConfig): # type: ignore
model_type: str = "plamo3"
hidden_size: int
......@@ -80,7 +81,7 @@ def rms_norm_weight_loader(offset: float) -> LoaderFunction:
class DenseMLP(nn.Module):
def __init__(
self,
config: Plamo3Config,
config: "Plamo3Config",
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment