Unverified Commit c4e46433 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent d1557e66
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -156,8 +156,10 @@ class Medusa(nn.Module): ...@@ -156,8 +156,10 @@ class Medusa(nn.Module):
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
weights_map = {} weights_map = {}
...@@ -181,9 +183,12 @@ class Medusa(nn.Module): ...@@ -181,9 +183,12 @@ class Medusa(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.token_map is not None: if self.token_map is not None:
self.token_map.to(device=self.lm_heads[0].weight.device) self.token_map.to(device=self.lm_heads[0].weight.device)
assert (self.truncated_vocab_size assert (self.truncated_vocab_size
== self.orig_vocab_size) or (self.token_map is not None) == self.orig_vocab_size) or (self.token_map is not None)
return loaded_params
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" """Inference-only MiniCPM model compatible with HuggingFace weights."""
import math import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -539,7 +539,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -539,7 +539,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -556,6 +557,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -556,6 +557,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
for weight_name in ["w1", "w2", "w3"] for weight_name in ["w1", "w2", "w3"]
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -606,3 +608,5 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -606,3 +608,5 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -24,7 +24,7 @@ import math ...@@ -24,7 +24,7 @@ import math
import re import re
from functools import partial from functools import partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union) Set, Tuple, TypedDict, Union)
import torch import torch
import torch.types import torch.types
...@@ -602,7 +602,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -602,7 +602,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -612,6 +613,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -612,6 +613,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name: if key_to_modify in name:
...@@ -630,10 +632,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -630,10 +632,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
if is_pp_missing_parameter( name = name.replace(weight_name, param_name)
name.replace(weight_name, param_name), self): if is_pp_missing_parameter(name, self):
continue continue
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
...@@ -646,6 +648,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -646,6 +648,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def get_mm_mapping(self) -> MultiModelKeys: def get_mm_mapping(self) -> MultiModelKeys:
""" """
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -404,7 +404,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -404,7 +404,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -421,6 +422,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -421,6 +422,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
num_experts=self.config.num_local_experts) num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -478,3 +480,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -478,3 +480,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -409,7 +409,8 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -409,7 +409,8 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -418,6 +419,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -418,6 +419,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -448,3 +450,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -448,3 +450,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""PyTorch Mllama model.""" """PyTorch Mllama model."""
import math import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np import numpy as np
...@@ -1427,7 +1427,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1427,7 +1427,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
return outputs return outputs
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -1437,7 +1438,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1437,7 +1438,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
updated_params = set() updated_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if 'patch_embedding.weight' in name: if 'patch_embedding.weight' in name:
name = name.replace('patch_embedding.weight', name = name.replace('patch_embedding.weight',
...@@ -1457,6 +1458,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1457,6 +1458,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
......
import math import math
from typing import Iterable, List, Tuple from typing import Iterable, List, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -188,11 +188,15 @@ class MLPSpeculator(nn.Module): ...@@ -188,11 +188,15 @@ class MLPSpeculator(nn.Module):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
param = params_dict.get(name.replace("speculator.", "")) param = params_dict.get(name.replace("speculator.", ""))
if param is not None: if param is not None:
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -324,8 +324,10 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -324,8 +324,10 @@ class MPTForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
...@@ -336,3 +338,5 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -336,3 +338,5 @@ class MPTForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Nemotron model compatible with HuggingFace weights.""" """Inference-only Nemotron model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -474,7 +474,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -474,7 +474,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -482,6 +483,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -482,6 +483,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".qkv_proj", ".v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -522,3 +524,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -522,3 +524,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights.""" """Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -356,7 +356,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP): ...@@ -356,7 +356,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -366,6 +367,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP): ...@@ -366,6 +367,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -402,3 +404,5 @@ class OlmoForCausalLM(nn.Module, SupportsPP): ...@@ -402,3 +404,5 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights.""" """Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -364,7 +364,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): ...@@ -364,7 +364,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -383,6 +384,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): ...@@ -383,6 +384,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
num_experts=self.config.num_experts) num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -455,3 +457,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): ...@@ -455,3 +457,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.""" """Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -394,7 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP): ...@@ -394,7 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -402,6 +403,7 @@ class OPTForCausalLM(nn.Module, SupportsPP): ...@@ -402,6 +403,7 @@ class OPTForCausalLM(nn.Module, SupportsPP):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "lm_head.weight" in name and self.config.tie_word_embeddings: if "lm_head.weight" in name and self.config.tie_word_embeddings:
continue continue
...@@ -431,3 +433,5 @@ class OPTForCausalLM(nn.Module, SupportsPP): ...@@ -431,3 +433,5 @@ class OPTForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Copyright (c) OrionStar Inc. # Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights.""" """Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -327,7 +327,8 @@ class OrionForCausalLM(nn.Module, SupportsPP): ...@@ -327,7 +327,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -337,6 +338,7 @@ class OrionForCausalLM(nn.Module, SupportsPP): ...@@ -337,6 +338,7 @@ class OrionForCausalLM(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -368,3 +370,5 @@ class OrionForCausalLM(nn.Module, SupportsPP): ...@@ -368,3 +370,5 @@ class OrionForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import torch import torch
...@@ -295,6 +295,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -295,6 +295,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights) return loader.load_weights(weights)
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights.""" """Inference-only persimmon model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -324,8 +324,10 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): ...@@ -324,8 +324,10 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -358,3 +360,5 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): ...@@ -358,3 +360,5 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -345,7 +345,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -345,7 +345,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -353,6 +354,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -353,6 +354,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
("qkv_proj", "v_proj", "v") ("qkv_proj", "v_proj", "v")
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
...@@ -383,3 +385,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -383,3 +385,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -457,9 +457,11 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): ...@@ -457,9 +457,11 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -471,3 +473,5 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): ...@@ -471,3 +473,5 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import itertools import itertools
import re import re
from functools import cached_property, lru_cache from functools import cached_property, lru_cache
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
import numpy as np import numpy as np
...@@ -744,7 +744,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -744,7 +744,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens", "model.vision_embed_tokens.wte": "embed_tokens",
...@@ -759,5 +760,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -759,5 +760,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# The HF config doesn't specify whether these are tied, # The HF config doesn't specify whether these are tied,
# so we detect it this way # so we detect it this way
if "embed_tokens" not in autoloaded_weights: if "embed_tokens.weight" not in autoloaded_weights:
self.embed_tokens = self.language_model.model.embed_tokens self.embed_tokens = self.language_model.model.embed_tokens
autoloaded_weights.add("embed_tokens.weight")
return autoloaded_weights
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only PhiMoE model.""" """Inference-only PhiMoE model."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -598,7 +598,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -598,7 +598,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -613,6 +614,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -613,6 +614,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
num_experts=self.config.num_local_experts) num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -666,3 +668,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -666,3 +668,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from functools import cached_property from functools import cached_property
from itertools import tee from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
import numpy import numpy
import torch import torch
...@@ -1053,7 +1053,8 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1053,7 +1053,8 @@ class PixtralHFVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded # (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -1063,6 +1064,7 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1063,6 +1064,7 @@ class PixtralHFVisionModel(nn.Module):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.transformer.layers) layer_count = len(self.transformer.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
...@@ -1075,8 +1077,8 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1075,8 +1077,8 @@ class PixtralHFVisionModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name)
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
...@@ -1085,3 +1087,5 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1085,3 +1087,5 @@ class PixtralHFVisionModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment