Unverified Commit c4e46433 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent d1557e66
...@@ -8,7 +8,7 @@ import math ...@@ -8,7 +8,7 @@ import math
import re import re
from functools import partial from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, TypedDict, Union) Optional, Set, Tuple, TypedDict, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -964,13 +964,15 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -964,13 +964,15 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0), ("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1), ("gate_up_proj", "w1", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -999,6 +1001,8 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -999,6 +1001,8 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class QWenLLM(QWenBaseModel): class QWenLLM(QWenBaseModel):
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" """Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -332,7 +332,8 @@ class Qwen2Model(nn.Module): ...@@ -332,7 +332,8 @@ class Qwen2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -342,6 +343,7 @@ class Qwen2Model(nn.Module): ...@@ -342,6 +343,7 @@ class Qwen2Model(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -372,6 +374,8 @@ class Qwen2Model(nn.Module): ...@@ -372,6 +374,8 @@ class Qwen2Model(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...@@ -494,13 +498,14 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -494,13 +498,14 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None), if self.config.tie_word_embeddings else None),
) )
loader.load_weights(weights) return loader.load_weights(weights)
class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
...@@ -564,7 +569,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -564,7 +569,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."]) ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights) return loader.load_weights(weights)
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import lru_cache from functools import lru_cache
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import librosa import librosa
import numpy as np import numpy as np
...@@ -420,7 +421,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -420,7 +421,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -430,6 +432,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -430,6 +432,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -463,3 +466,5 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -463,3 +466,5 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Copyright 2024 The Qwen team. # Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
"""Inference-only Qwen2-Classification model compatible with HF weights.""" """Inference-only Qwen2-Classification model compatible with HF weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -97,7 +97,8 @@ class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP): ...@@ -97,7 +97,8 @@ class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."]) ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights) return loader.load_weights(weights)
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -436,7 +436,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -436,7 +436,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -455,6 +456,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -455,6 +456,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
num_experts=self.config.num_experts) num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -532,3 +534,5 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -532,3 +534,5 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Copyright 2024 The Qwen team. # Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" """Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -110,7 +110,8 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -110,7 +110,8 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."]) ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights) return loader.load_weights(weights)
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import partial from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, Type, TypedDict, Union) Optional, Set, Tuple, Type, TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -1333,7 +1333,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1333,7 +1333,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -1343,6 +1344,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1343,6 +1344,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -1392,3 +1394,5 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1392,3 +1394,5 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
within a vision language model.""" within a vision language model."""
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -594,7 +594,8 @@ class SiglipVisionModel(nn.Module): ...@@ -594,7 +594,8 @@ class SiglipVisionModel(nn.Module):
interpolate_pos_encoding=interpolate_pos_encoding, interpolate_pos_encoding=interpolate_pos_encoding,
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -602,6 +603,7 @@ class SiglipVisionModel(nn.Module): ...@@ -602,6 +603,7 @@ class SiglipVisionModel(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] if self.shard_weight else [] ] if self.shard_weight else []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
...@@ -619,8 +621,9 @@ class SiglipVisionModel(nn.Module): ...@@ -619,8 +621,9 @@ class SiglipVisionModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name)
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
...@@ -629,3 +632,5 @@ class SiglipVisionModel(nn.Module): ...@@ -629,3 +632,5 @@ class SiglipVisionModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights.""" """Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -477,7 +477,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -477,7 +477,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -487,6 +488,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -487,6 +488,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -502,6 +504,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -502,6 +504,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader) default_weight_loader)
loaded_weight = loaded_weight[0] loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -535,6 +538,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -535,6 +538,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights.""" model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -306,7 +306,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -306,7 +306,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -316,6 +317,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -316,6 +317,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -347,3 +349,5 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -347,3 +349,5 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch Starcoder2 model.""" """ PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -314,7 +314,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -314,7 +314,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -323,6 +324,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -323,6 +324,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -346,3 +348,5 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -346,3 +348,5 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import math import math
from functools import cached_property, lru_cache from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union, cast) TypedDict, Union, cast)
import numpy as np import numpy as np
...@@ -504,10 +504,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -504,10 +504,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
loader = AutoWeightsLoader(self, loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["audio_tower."]) ignore_unexpected_prefixes=["audio_tower."])
loader.load_weights(weights, mapper=hf_to_vllm_mapper) return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
import itertools import itertools
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Protocol, Tuple, Union, overload) Optional, Protocol, Set, Tuple, Union, overload)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -172,8 +172,9 @@ class AutoWeightsLoader: ...@@ -172,8 +172,9 @@ class AutoWeightsLoader:
if module != self.module: if module != self.module:
module_load_weights = getattr(module, "load_weights", None) module_load_weights = getattr(module, "load_weights", None)
if callable(module_load_weights): if callable(module_load_weights):
module_load_weights(weights) loaded_params = module_load_weights(weights)
return yield from map(lambda x: self._get_qualname(base_prefix, x),
loaded_params)
child_modules = dict(module.named_children()) child_modules = dict(module.named_children())
child_params = dict(module.named_parameters(recurse=False)) child_params = dict(module.named_parameters(recurse=False))
...@@ -222,11 +223,11 @@ class AutoWeightsLoader: ...@@ -222,11 +223,11 @@ class AutoWeightsLoader:
weights: Iterable[Tuple[str, torch.Tensor]], weights: Iterable[Tuple[str, torch.Tensor]],
*, *,
mapper: Optional[WeightsMapper] = None, mapper: Optional[WeightsMapper] = None,
) -> List[str]: ) -> Set[str]:
if mapper is not None: if mapper is not None:
weights = mapper.apply(weights) weights = mapper.apply(weights)
autoloaded_weights = list(self._load_module("", self.module, weights)) autoloaded_weights = set(self._load_module("", self.module, weights))
return autoloaded_weights return autoloaded_weights
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights.""" """Inference-only Xverse model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -376,7 +376,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -376,7 +376,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
...@@ -385,6 +386,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -385,6 +386,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if ("rotary_emb.inv_freq" in name if ("rotary_emb.inv_freq" in name
or "rotary_emb.cos_cached" in name or "rotary_emb.cos_cached" in name
...@@ -413,3 +415,5 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -413,3 +415,5 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment