"vscode:/vscode.git/clone" did not exist on "6367bde739dc3cd558f22d8db2021098c6435e31"
Unverified commit d1f6d1c8, authored by Isotr0py and committed by GitHub
Browse files

[Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739)


Signed-off-by: Isotr0py <2037008807@qq.com>
parent 6d525288
...@@ -20,6 +20,7 @@ class RMSNorm(CustomOp): ...@@ -20,6 +20,7 @@ class RMSNorm(CustomOp):
hidden_size: int, hidden_size: int,
eps: float = 1e-6, eps: float = 1e-6,
var_hidden_size: Optional[int] = None, var_hidden_size: Optional[int] = None,
has_weight: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -27,7 +28,11 @@ class RMSNorm(CustomOp): ...@@ -27,7 +28,11 @@ class RMSNorm(CustomOp):
self.variance_epsilon = eps self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size) else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size)) self.has_weight = has_weight
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
def forward_native( def forward_native(
self, self,
...@@ -59,7 +64,9 @@ class RMSNorm(CustomOp): ...@@ -59,7 +64,9 @@ class RMSNorm(CustomOp):
variance = x_var.pow(2).mean(dim=-1, keepdim=True) variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon) x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None: if residual is None:
return x return x
else: else:
......
...@@ -40,6 +40,7 @@ class MambaMixer(CustomOp): ...@@ -40,6 +40,7 @@ class MambaMixer(CustomOp):
use_conv_bias: bool, use_conv_bias: bool,
use_bias: bool, use_bias: bool,
use_rms_norm: bool, use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5, rms_norm_eps: float = 1e-5,
activation="silu"): activation="silu"):
super().__init__() super().__init__()
...@@ -105,14 +106,23 @@ class MambaMixer(CustomOp): ...@@ -105,14 +106,23 @@ class MambaMixer(CustomOp):
input_is_parallel=True, input_is_parallel=True,
) )
self.dt_layernorm = RMSNorm(time_step_rank, self.dt_layernorm = RMSNorm(
eps=rms_norm_eps) if use_rms_norm else None time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.b_layernorm = RMSNorm(ssm_state_size, self.b_layernorm = RMSNorm(
eps=rms_norm_eps) if use_rms_norm else None ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.c_layernorm = RMSNorm(ssm_state_size, self.c_layernorm = RMSNorm(
eps=rms_norm_eps) if use_rms_norm else None ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
def forward_native(self, hidden_states: torch.Tensor, def forward_native(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
......
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -47,6 +47,7 @@ class MambaDecoderLayer(nn.Module): ...@@ -47,6 +47,7 @@ class MambaDecoderLayer(nn.Module):
use_conv_bias=config.use_conv_bias, use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias, use_bias=config.use_bias,
use_rms_norm=self.is_falcon_mamba, use_rms_norm=self.is_falcon_mamba,
rms_norm_has_weight=not self.is_falcon_mamba,
rms_norm_eps=mixer_rms_eps, rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act) activation=config.hidden_act)
...@@ -241,8 +242,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -241,8 +242,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "A_log" in name: if "A_log" in name:
name = name.replace("A_log", "A") name = name.replace("A_log", "A")
...@@ -254,3 +257,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -254,3 +257,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment