Unverified commit c4e46433, authored by Isotr0py, committed by GitHub
Browse files

[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327)


Signed-off-by: Isotr0py <2037008807@qq.com>
parent d1557e66
...@@ -334,7 +334,17 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -334,7 +334,17 @@ class DefaultModelLoader(BaseModelLoader):
with target_device: with target_device:
model = _initialize_model(vllm_config=vllm_config) model = _initialize_model(vllm_config=vllm_config)
model.load_weights(self._get_all_weights(model_config, model)) weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
......
"""Inference-only Snowflake Arctic model.""" """Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -480,7 +480,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP): ...@@ -480,7 +480,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -518,6 +519,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP): ...@@ -518,6 +519,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
("ws", f"experts.{expert_id}.w3.weight", expert_id)) ("ws", f"experts.{expert_id}.w3.weight", expert_id))
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
logger.info( logger.info(
"It will take ~10 minutes loading from the 16-bit weights. " "It will take ~10 minutes loading from the 16-bit weights. "
...@@ -573,3 +575,5 @@ class ArcticForCausalLM(nn.Module, SupportsPP): ...@@ -573,3 +575,5 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" """Inference-only BaiChuan model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -404,13 +404,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -404,13 +404,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -449,6 +451,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -449,6 +451,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
......
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -337,7 +337,8 @@ class BertModel(nn.Module): ...@@ -337,7 +337,8 @@ class BertModel(nn.Module):
return self.encoder(hidden_states, kv_caches, attn_metadata) return self.encoder(hidden_states, kv_caches, attn_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"), ("qkv_proj", "query", "q"),
...@@ -346,6 +347,7 @@ class BertModel(nn.Module): ...@@ -346,6 +347,7 @@ class BertModel(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "pooler" in name: if "pooler" in name:
continue continue
...@@ -368,6 +370,8 @@ class BertModel(nn.Module): ...@@ -368,6 +370,8 @@ class BertModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BertEmbeddingModel(nn.Module): class BertEmbeddingModel(nn.Module):
......
"""Minimal implementation of BlipVisionModel intended to be only used """Minimal implementation of BlipVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from typing import Iterable, Optional, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -415,7 +415,8 @@ class BlipVisionModel(nn.Module): ...@@ -415,7 +415,8 @@ class BlipVisionModel(nn.Module):
return self.post_layernorm(hidden_states) return self.post_layernorm(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -423,6 +424,7 @@ class BlipVisionModel(nn.Module): ...@@ -423,6 +424,7 @@ class BlipVisionModel(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] if self.shard_weight else [] ] if self.shard_weight else []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.encoder.layers) layer_count = len(self.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
...@@ -440,8 +442,8 @@ class BlipVisionModel(nn.Module): ...@@ -440,8 +442,8 @@ class BlipVisionModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name)
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
...@@ -450,3 +452,5 @@ class BlipVisionModel(nn.Module): ...@@ -450,3 +452,5 @@ class BlipVisionModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import torch import torch
...@@ -692,6 +692,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -692,6 +692,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights) return loader.load_weights(weights)
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.""" """Inference-only BLOOM model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -341,8 +341,10 @@ class BloomForCausalLM(nn.Module, SupportsPP): ...@@ -341,8 +341,10 @@ class BloomForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name == "lm_head.weight": if name == "lm_head.weight":
continue continue
...@@ -371,3 +373,5 @@ class BloomForCausalLM(nn.Module, SupportsPP): ...@@ -371,3 +373,5 @@ class BloomForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
from functools import cached_property from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
import torch import torch
...@@ -1034,7 +1034,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1034,7 +1034,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -1044,6 +1045,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1044,6 +1045,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -1111,3 +1113,5 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1111,3 +1113,5 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
"""Inference-only ChatGLM model compatible with THUDM weights.""" """Inference-only ChatGLM model compatible with THUDM weights."""
from argparse import Namespace from argparse import Namespace
from array import array from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict)
import torch import torch
from PIL import Image from PIL import Image
...@@ -645,7 +646,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -645,7 +646,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = { merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
"transformer.vision.linear_proj.merged_proj.weight": { "transformer.vision.linear_proj.merged_proj.weight": {
...@@ -655,6 +657,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -655,6 +657,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
} }
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
is_weight_to_be_merge = False is_weight_to_be_merge = False
for _, merged_weight_dict in merged_weights_dict.items(): for _, merged_weight_dict in merged_weights_dict.items():
...@@ -677,6 +680,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -677,6 +680,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
for combined_name, merged_weight_dict in merged_weights_dict.items(): for combined_name, merged_weight_dict in merged_weights_dict.items():
if combined_name in params_dict: if combined_name in params_dict:
...@@ -686,3 +690,5 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -686,3 +690,5 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, combined_weight) weight_loader(param, combined_weight)
loaded_params.add(combined_name)
return loaded_params
"""Minimal implementation of CLIPVisionModel intended to be only used """Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -483,7 +483,8 @@ class CLIPVisionModel(nn.Module): ...@@ -483,7 +483,8 @@ class CLIPVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded # (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -491,6 +492,7 @@ class CLIPVisionModel(nn.Module): ...@@ -491,6 +492,7 @@ class CLIPVisionModel(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] if self.shard_weight else [] ] if self.shard_weight else []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
...@@ -508,8 +510,9 @@ class CLIPVisionModel(nn.Module): ...@@ -508,8 +510,9 @@ class CLIPVisionModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name)
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
...@@ -518,3 +521,5 @@ class CLIPVisionModel(nn.Module): ...@@ -518,3 +521,5 @@ class CLIPVisionModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -402,7 +402,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -402,7 +402,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -447,3 +448,4 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -447,3 +448,4 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -417,13 +417,15 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -417,13 +417,15 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
expert_params_mapping = [( expert_params_mapping = [(
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight", "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
f"mlp.{weight_name}", f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]] ) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping: for param_name, weight_name in expert_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -447,3 +449,5 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -447,3 +449,5 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights.""" """Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Tuple from typing import Iterable, Set, Tuple
import torch import torch
...@@ -57,7 +57,8 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -57,7 +57,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
delattr(config, "num_key_value_heads_per_layer") delattr(config, "num_key_value_heads_per_layer")
super().__init__(vllm_config=vllm_config) super().__init__(vllm_config=vllm_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -67,6 +68,7 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -67,6 +68,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -97,6 +99,8 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -97,6 +99,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor: def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Deepseek model.""" """Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -442,7 +442,8 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -442,7 +442,8 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -453,6 +454,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -453,6 +454,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -487,3 +489,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -487,3 +489,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only DeepseekV2 model.""" """Inference-only DeepseekV2 model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -550,7 +550,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -550,7 +550,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
device=device), device=device),
}) })
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
...@@ -566,6 +567,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -566,6 +567,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
num_experts=self.config.n_routed_experts) num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -623,3 +625,5 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -623,3 +625,5 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights.""" """Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -513,7 +513,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -513,7 +513,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -523,6 +524,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -523,6 +524,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".gate_up_proj", ".c_fc_1", 1), (".gate_up_proj", ".c_fc_1", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -543,6 +545,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -543,6 +545,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader) default_weight_loader)
loaded_weight = loaded_weight[0] loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -576,6 +579,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -576,6 +579,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
"""PyTorch Falcon model.""" """PyTorch Falcon model."""
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -473,7 +473,8 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -473,7 +473,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
total_num_heads = self.config.num_attention_heads total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture: if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads total_num_kv_heads = self.config.num_kv_heads
...@@ -483,6 +484,7 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -483,6 +484,7 @@ class FalconForCausalLM(nn.Module, SupportsPP):
total_num_kv_heads = total_num_heads total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name == "lm_head.weight" and self.tie_word_embeddings: if name == "lm_head.weight" and self.tie_word_embeddings:
# Falcon uses tied embeddings except Falcon-11b. # Falcon uses tied embeddings except Falcon-11b.
...@@ -519,3 +521,5 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -519,3 +521,5 @@ class FalconForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
import math import math
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -156,7 +156,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module): ...@@ -156,7 +156,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -165,12 +166,13 @@ class Florence2LanguageForConditionalGeneration(nn.Module): ...@@ -165,12 +166,13 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name)
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
...@@ -183,6 +185,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module): ...@@ -183,6 +185,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Florence2ForConditionalGeneration(nn.Module): class Florence2ForConditionalGeneration(nn.Module):
...@@ -248,10 +252,11 @@ class Florence2ForConditionalGeneration(nn.Module): ...@@ -248,10 +252,11 @@ class Florence2ForConditionalGeneration(nn.Module):
) -> SamplerOutput: ) -> SamplerOutput:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
skip_prefixes = [ skip_prefixes = [
'image_projection', "vision_tower", "image_proj_norm", 'image_projection', "vision_tower", "image_proj_norm",
"image_pos_embed", "visual_temporal_embed" "image_pos_embed", "visual_temporal_embed"
] ]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader.load_weights(weights) return loader.load_weights(weights)
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from array import array from array import array
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -354,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -354,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
next_tokens = self.language_model.sampler(logits, sampling_metadata) next_tokens = self.language_model.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights) return loader.load_weights(weights)
...@@ -424,7 +424,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -424,7 +424,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -469,3 +470,4 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -469,3 +470,4 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logger.warning( logger.warning(
"Some weights are not initialized from checkpoints: %s", "Some weights are not initialized from checkpoints: %s",
unloaded_params) unloaded_params)
return loaded_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment