Unverified commit c4e46433, authored by Isotr0py, committed by GitHub
Browse files

[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327)


Signed-off-by: Isotr0py <2037008807@qq.com>
parent d1557e66
...@@ -312,7 +312,8 @@ class Gemma2Model(nn.Module): ...@@ -312,7 +312,8 @@ class Gemma2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -354,6 +355,7 @@ class Gemma2Model(nn.Module): ...@@ -354,6 +355,7 @@ class Gemma2Model(nn.Module):
logger.warning( logger.warning(
"Some weights are not initialized from checkpoints: %s", "Some weights are not initialized from checkpoints: %s",
unloaded_params) unloaded_params)
return loaded_params
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...@@ -451,13 +453,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -451,13 +453,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None), if self.config.tie_word_embeddings else None),
) )
loader.load_weights(weights) return loader.load_weights(weights)
class Gemma2EmbeddingModel(nn.Module, SupportsPP): class Gemma2EmbeddingModel(nn.Module, SupportsPP):
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -298,8 +298,10 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -298,8 +298,10 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
...@@ -328,3 +330,5 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -328,3 +330,5 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" """Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -323,8 +323,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -323,8 +323,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "lm_head.weight" in name: if "lm_head.weight" in name:
continue continue
...@@ -344,3 +346,5 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -344,3 +346,5 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight, 'v') weight_loader(param, loaded_weight, 'v')
else: else:
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.""" """Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -291,7 +291,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -291,7 +291,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -301,6 +302,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -301,6 +302,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name: if "attn.bias" in name or "attn.masked_bias" in name:
continue continue
...@@ -330,3 +332,5 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -330,3 +332,5 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -303,8 +303,10 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -303,8 +303,10 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if ("attention.bias" in name or "attention.masked_bias" in name if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name): or "rotary_emb.inv_freq" in name):
...@@ -337,3 +339,5 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -337,3 +339,5 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights.""" """Inference-only IBM Granite model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -455,7 +455,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -455,7 +455,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device), device=device),
}) })
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -465,6 +466,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -465,6 +466,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -485,6 +487,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -485,6 +487,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader) default_weight_loader)
loaded_weight = loaded_weight[0] loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -518,6 +521,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -518,6 +521,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GraniteMoe model.""" """Inference-only GraniteMoe model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -419,7 +419,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -419,7 +419,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
new_weights = {} new_weights = {}
for n, p in weights: for n, p in weights:
if n.endswith('.block_sparse_moe.input_linear.weight'): if n.endswith('.block_sparse_moe.input_linear.weight'):
...@@ -452,4 +453,5 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -452,4 +453,5 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
pass pass
else: else:
new_weights[n] = p new_weights[n] = p
mixtral.MixtralForCausalLM.load_weights(self, new_weights.items()) return mixtral.MixtralForCausalLM.load_weights(self,
new_weights.items())
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
"""PyTorch Idefics2 model.""" """PyTorch Idefics2 model."""
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -331,7 +331,8 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -331,7 +331,8 @@ class Idefics2VisionTransformer(nn.Module):
last_hidden_state = self.post_layernorm(encoder_outputs) last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state return last_hidden_state
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -339,11 +340,13 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -339,11 +340,13 @@ class Idefics2VisionTransformer(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
...@@ -352,3 +355,5 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -352,3 +355,5 @@ class Idefics2VisionTransformer(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import math import math
from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple, from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
Optional, Tuple, TypedDict, Union) Optional, Set, Tuple, TypedDict, Union)
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -751,9 +751,10 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -751,9 +751,10 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights) return loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from functools import partial from functools import partial
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -469,10 +469,14 @@ class InternVisionModel(nn.Module): ...@@ -469,10 +469,14 @@ class InternVisionModel(nn.Module):
return encoder_outputs return encoder_outputs
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -369,13 +369,15 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP): ...@@ -369,13 +369,15 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0), ("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1), ("gate_up_proj", "w3", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -402,3 +404,5 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP): ...@@ -402,3 +404,5 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# -------------------------------------------------------- # --------------------------------------------------------
import re import re
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import torch import torch
...@@ -663,6 +663,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -663,6 +663,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights) return loader.load_weights(weights)
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
"""Inference-only Jais model compatible with HuggingFace weights.""" """Inference-only Jais model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -350,8 +350,10 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -350,8 +350,10 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
...@@ -382,3 +384,5 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -382,3 +384,5 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
"""Inference-only Jamba model.""" """Inference-only Jamba model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -462,7 +462,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -462,7 +462,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -479,6 +480,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -479,6 +480,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
num_experts=self.config.num_experts) num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -534,6 +536,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -534,6 +536,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _is_moe_layer(name: str): def _is_moe_layer(name: str):
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -350,7 +350,8 @@ class LlamaModel(nn.Module): ...@@ -350,7 +350,8 @@ class LlamaModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -360,6 +361,7 @@ class LlamaModel(nn.Module): ...@@ -360,6 +361,7 @@ class LlamaModel(nn.Module):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -375,6 +377,7 @@ class LlamaModel(nn.Module): ...@@ -375,6 +377,7 @@ class LlamaModel(nn.Module):
default_weight_loader) default_weight_loader)
loaded_weight = loaded_weight[0] loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -390,7 +393,6 @@ class LlamaModel(nn.Module): ...@@ -390,7 +393,6 @@ class LlamaModel(nn.Module):
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
...@@ -408,6 +410,8 @@ class LlamaModel(nn.Module): ...@@ -408,6 +410,8 @@ class LlamaModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
...@@ -577,13 +581,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -577,13 +581,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None), if self.config.tie_word_embeddings else None),
) )
loader.load_weights( return loader.load_weights(
self.maybe_remap_mistral(name, loaded_weight) self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights) for name, loaded_weight in weights)
......
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
import torch import torch
...@@ -547,6 +547,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -547,6 +547,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights) return loader.load_weights(weights)
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import torch import torch
...@@ -654,6 +654,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -654,6 +654,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights) return loader.load_weights(weights)
import math import math
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np import numpy as np
...@@ -445,10 +445,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -445,10 +445,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
# This model doesn't support images for now # This model doesn't support images for now
ignore_unexpected_prefixes=["image_newline"], ignore_unexpected_prefixes=["image_newline"],
) )
loader.load_weights(weights) return loader.load_weights(weights)
import math import math
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np import numpy as np
...@@ -887,6 +887,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -887,6 +887,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights) return loader.load_weights(weights)
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -243,8 +243,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -243,8 +243,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "A_log" in name: if "A_log" in name:
name = name.replace("A_log", "A") name = name.replace("A_log", "A")
...@@ -256,3 +258,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -256,3 +258,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment