Unverified commit ad44437b, authored by Isotr0py, committed by GitHub
Browse files

[Bugfix] Fix Mamba model initialization and MLP Speculator weights loading (#10456)


Signed-off-by: Isotr0py <2037008807@qq.com>
parent 9e05252b
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -243,10 +243,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -243,10 +243,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "A_log" in name: if "A_log" in name:
name = name.replace("A_log", "A") name = name.replace("A_log", "A")
...@@ -258,5 +256,3 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -258,5 +256,3 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -193,7 +193,8 @@ class MLPSpeculator(nn.Module): ...@@ -193,7 +193,8 @@ class MLPSpeculator(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
param = params_dict.get(name.replace("speculator.", "")) name = name.replace("speculator.", "")
param = params_dict.get(name)
if param is not None: if param is not None:
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment