Unverified commit d6bb2a9d, authored by Harry Mellor, committed by GitHub
Browse files

Fix Plamo 2/3 & LFM2 for Transformers v5 (#38090)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 1e673a43
...@@ -52,7 +52,7 @@ class Lfm2MLP(nn.Module): ...@@ -52,7 +52,7 @@ class Lfm2MLP(nn.Module):
def __init__( def __init__(
self, self,
dim: int, dim: int,
ff_dim: int, intermediate_size: int,
multiple_of: int, multiple_of: int,
auto_adjust_ff_dim: bool, auto_adjust_ff_dim: bool,
ffn_dim_multiplier: float | None, ffn_dim_multiplier: float | None,
...@@ -61,21 +61,23 @@ class Lfm2MLP(nn.Module): ...@@ -61,21 +61,23 @@ class Lfm2MLP(nn.Module):
): ):
super().__init__() super().__init__()
if auto_adjust_ff_dim: if auto_adjust_ff_dim:
ff_dim = int(2 * ff_dim / 3) intermediate_size = int(2 * intermediate_size / 3)
# custom dim factor multiplier # custom dim factor multiplier
if ffn_dim_multiplier is not None: if ffn_dim_multiplier is not None:
ff_dim = int(ffn_dim_multiplier * ff_dim) intermediate_size = int(ffn_dim_multiplier * intermediate_size)
ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) intermediate_size = multiple_of * (
(intermediate_size + multiple_of - 1) // multiple_of
)
self.w13 = MergedColumnParallelLinear( self.w13 = MergedColumnParallelLinear(
input_size=dim, input_size=dim,
output_sizes=[ff_dim] * 2, output_sizes=[intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.w13", prefix=f"{prefix}.w13",
) )
self.w2 = RowParallelLinear( self.w2 = RowParallelLinear(
input_size=ff_dim, input_size=intermediate_size,
output_size=dim, output_size=dim,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
...@@ -212,7 +214,7 @@ class Lfm2AttentionDecoderLayer(nn.Module): ...@@ -212,7 +214,7 @@ class Lfm2AttentionDecoderLayer(nn.Module):
self.feed_forward = Lfm2MLP( self.feed_forward = Lfm2MLP(
dim=config.block_dim, dim=config.block_dim,
ff_dim=config.block_ff_dim, intermediate_size=config.intermediate_size,
multiple_of=config.block_multiple_of, multiple_of=config.block_multiple_of,
auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
ffn_dim_multiplier=config.block_ffn_dim_multiplier, ffn_dim_multiplier=config.block_ffn_dim_multiplier,
...@@ -262,7 +264,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module): ...@@ -262,7 +264,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
self.feed_forward = Lfm2MLP( self.feed_forward = Lfm2MLP(
dim=config.block_dim, dim=config.block_dim,
ff_dim=config.block_ff_dim, intermediate_size=config.intermediate_size,
multiple_of=config.block_multiple_of, multiple_of=config.block_multiple_of,
auto_adjust_ff_dim=config.block_auto_adjust_ff_dim, auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
ffn_dim_multiplier=config.block_ffn_dim_multiplier, ffn_dim_multiplier=config.block_ffn_dim_multiplier,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice from itertools import islice
from typing import TYPE_CHECKING
import torch import torch
from torch import nn from torch import nn
...@@ -71,9 +72,10 @@ from vllm.utils.torch_utils import direct_register_custom_op ...@@ -71,9 +72,10 @@ from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Only used for type hinting. # Only used for type hinting.
class Plamo2Config(PretrainedConfig): # type: ignore if TYPE_CHECKING:
class Plamo2Config(PretrainedConfig): # type: ignore
model_type: str = "plamo2" model_type: str = "plamo2"
hidden_size: int hidden_size: int
...@@ -94,7 +96,7 @@ class Plamo2Config(PretrainedConfig): # type: ignore ...@@ -94,7 +96,7 @@ class Plamo2Config(PretrainedConfig): # type: ignore
vocab_size: int vocab_size: int
def is_mamba(config: Plamo2Config, i: int) -> bool: def is_mamba(config: "Plamo2Config", i: int) -> bool:
assert config.mamba_step > 1 assert config.mamba_step > 1
if config.num_hidden_layers <= (config.mamba_step // 2): if config.num_hidden_layers <= (config.mamba_step // 2):
...@@ -502,7 +504,7 @@ direct_register_custom_op( ...@@ -502,7 +504,7 @@ direct_register_custom_op(
class DenseMLP(nn.Module): class DenseMLP(nn.Module):
def __init__( def __init__(
self, self,
config: Plamo2Config, config: "Plamo2Config",
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice from itertools import islice
from typing import Any from typing import TYPE_CHECKING, Any
import torch import torch
from torch import nn from torch import nn
...@@ -46,9 +46,10 @@ from vllm.model_executor.models.utils import ( ...@@ -46,9 +46,10 @@ from vllm.model_executor.models.utils import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
# Only used for type hinting. # Only used for type hinting.
class Plamo3Config(PretrainedConfig): # type: ignore if TYPE_CHECKING:
class Plamo3Config(PretrainedConfig): # type: ignore
model_type: str = "plamo3" model_type: str = "plamo3"
hidden_size: int hidden_size: int
...@@ -80,7 +81,7 @@ def rms_norm_weight_loader(offset: float) -> LoaderFunction: ...@@ -80,7 +81,7 @@ def rms_norm_weight_loader(offset: float) -> LoaderFunction:
class DenseMLP(nn.Module): class DenseMLP(nn.Module):
def __init__( def __init__(
self, self,
config: Plamo3Config, config: "Plamo3Config",
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment