Unverified Commit c4e46433 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent d1557e66
......@@ -8,7 +8,7 @@ import math
import re
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, TypedDict, Union)
Optional, Set, Tuple, TypedDict, Union)
import numpy as np
import torch
......@@ -964,13 +964,15 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -999,6 +1001,8 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class QWenLLM(QWenBaseModel):
......
......@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -332,7 +332,8 @@ class Qwen2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -342,6 +343,7 @@ class Qwen2Model(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -372,6 +374,8 @@ class Qwen2Model(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
......@@ -494,13 +498,14 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(weights)
return loader.load_weights(weights)
class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
......@@ -564,7 +569,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights)
return loader.load_weights(weights)
......@@ -20,7 +20,8 @@
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import librosa
import numpy as np
......@@ -420,7 +421,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -430,6 +432,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -463,3 +466,5 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -4,7 +4,7 @@
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-Classification model compatible with HF weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
......@@ -97,7 +97,8 @@ class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights)
return loader.load_weights(weights)
......@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.nn.functional as F
......@@ -436,7 +436,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -455,6 +456,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -532,3 +534,5 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -3,7 +3,7 @@
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -110,7 +110,8 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights)
return loader.load_weights(weights)
......@@ -23,7 +23,7 @@
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, Type, TypedDict, Union)
Optional, Set, Tuple, Type, TypedDict, Union)
import torch
import torch.nn as nn
......@@ -1333,7 +1333,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -1343,6 +1344,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
("gate_up_proj", "gate_proj", 0),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -1392,3 +1394,5 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -2,7 +2,7 @@
within a vision language model."""
import math
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np
import torch
......@@ -594,7 +594,8 @@ class SiglipVisionModel(nn.Module):
interpolate_pos_encoding=interpolate_pos_encoding,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -602,6 +603,7 @@ class SiglipVisionModel(nn.Module):
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
......@@ -619,8 +621,9 @@ class SiglipVisionModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name.replace(weight_name, param_name)]
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
......@@ -629,3 +632,5 @@ class SiglipVisionModel(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -21,7 +21,7 @@
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -477,7 +477,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -487,6 +488,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -502,6 +504,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
......@@ -535,6 +538,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
......
......@@ -18,7 +18,7 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -306,7 +306,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -316,6 +317,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -347,3 +349,5 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -314,7 +314,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -323,6 +324,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -346,3 +348,5 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -3,7 +3,7 @@
import math
from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union, cast)
import numpy as np
......@@ -504,10 +504,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["audio_tower."])
loader.load_weights(weights, mapper=hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
import itertools
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Protocol, Tuple, Union, overload)
Optional, Protocol, Set, Tuple, Union, overload)
import torch
import torch.nn as nn
......@@ -172,8 +172,9 @@ class AutoWeightsLoader:
if module != self.module:
module_load_weights = getattr(module, "load_weights", None)
if callable(module_load_weights):
module_load_weights(weights)
return
loaded_params = module_load_weights(weights)
yield from map(lambda x: self._get_qualname(base_prefix, x),
loaded_params)
child_modules = dict(module.named_children())
child_params = dict(module.named_parameters(recurse=False))
......@@ -222,11 +223,11 @@ class AutoWeightsLoader:
weights: Iterable[Tuple[str, torch.Tensor]],
*,
mapper: Optional[WeightsMapper] = None,
) -> List[str]:
) -> Set[str]:
if mapper is not None:
weights = mapper.apply(weights)
autoloaded_weights = list(self._load_module("", self.module, weights))
autoloaded_weights = set(self._load_module("", self.module, weights))
return autoloaded_weights
......
......@@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -376,7 +376,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
......@@ -385,6 +386,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if ("rotary_emb.inv_freq" in name
or "rotary_emb.cos_cached" in name
......@@ -413,3 +415,5 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment