Unverified commit c4e46433, authored by Isotr0py, committed by GitHub
Browse files

[Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327)


Signed-off-by: Isotr0py <2037008807@qq.com>
parent d1557e66
......@@ -334,7 +334,17 @@ class DefaultModelLoader(BaseModelLoader):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
model.load_weights(self._get_all_weights(model_config, model))
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
# Currently, the strict check is enabled only for non-quantized models
# whose load_weights() reports the set of weights it loaded.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
......
"""Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -480,7 +480,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -518,6 +519,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
("ws", f"experts.{expert_id}.w3.weight", expert_id))
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
logger.info(
"It will take ~10 minutes loading from the 16-bit weights. "
......@@ -573,3 +575,5 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -18,7 +18,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -404,13 +404,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -449,6 +451,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
......
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
......@@ -337,7 +337,8 @@ class BertModel(nn.Module):
return self.encoder(hidden_states, kv_caches, attn_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"),
......@@ -346,6 +347,7 @@ class BertModel(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "pooler" in name:
continue
......@@ -368,6 +370,8 @@ class BertModel(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BertEmbeddingModel(nn.Module):
......
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
......@@ -415,7 +415,8 @@ class BlipVisionModel(nn.Module):
return self.post_layernorm(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -423,6 +424,7 @@ class BlipVisionModel(nn.Module):
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.encoder.layers)
for name, loaded_weight in weights:
......@@ -440,8 +442,8 @@ class BlipVisionModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
......@@ -450,3 +452,5 @@ class BlipVisionModel(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import torch
......@@ -692,6 +692,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
return loader.load_weights(weights)
......@@ -16,7 +16,7 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -341,8 +341,10 @@ class BloomForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
......@@ -371,3 +373,5 @@ class BloomForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
import torch
......@@ -1034,7 +1034,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -1044,6 +1045,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -1111,3 +1113,5 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -3,7 +3,8 @@
"""Inference-only ChatGLM model compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict)
import torch
from PIL import Image
......@@ -645,7 +646,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
"transformer.vision.linear_proj.merged_proj.weight": {
......@@ -655,6 +657,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
}
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
is_weight_to_be_merge = False
for _, merged_weight_dict in merged_weights_dict.items():
......@@ -677,6 +680,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
for combined_name, merged_weight_dict in merged_weights_dict.items():
if combined_name in params_dict:
......@@ -686,3 +690,5 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, combined_weight)
loaded_params.add(combined_name)
return loaded_params
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np
import torch
......@@ -483,7 +483,8 @@ class CLIPVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -491,6 +492,7 @@ class CLIPVisionModel(nn.Module):
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
......@@ -508,8 +510,9 @@ class CLIPVisionModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name.replace(weight_name, param_name)]
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
......@@ -518,3 +521,5 @@ class CLIPVisionModel(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -402,7 +402,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -447,3 +448,4 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
......@@ -417,13 +417,15 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
expert_params_mapping = [(
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
......@@ -447,3 +449,5 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Tuple
from typing import Iterable, Set, Tuple
import torch
......@@ -57,7 +57,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
delattr(config, "num_key_value_heads_per_layer")
super().__init__(vllm_config=vllm_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -67,6 +68,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -97,6 +99,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
hidden_size = self.config.hidden_size
......
......@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -442,7 +442,8 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -453,6 +454,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -487,3 +489,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2 model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -550,7 +550,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
......@@ -566,6 +567,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -623,3 +625,5 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
......@@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -513,7 +513,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -523,6 +524,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -543,6 +545,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
......@@ -576,6 +579,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
......
......@@ -18,7 +18,7 @@
"""PyTorch Falcon model."""
import math
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -473,7 +473,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
......@@ -483,6 +484,7 @@ class FalconForCausalLM(nn.Module, SupportsPP):
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight" and self.tie_word_embeddings:
# Falcon uses tied embeddings except Falcon-11b.
......@@ -519,3 +521,5 @@ class FalconForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.nn as nn
......@@ -156,7 +156,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -165,12 +166,13 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
......@@ -183,6 +185,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Florence2ForConditionalGeneration(nn.Module):
......@@ -248,10 +252,11 @@ class Florence2ForConditionalGeneration(nn.Module):
) -> SamplerOutput:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
skip_prefixes = [
'image_projection', "vision_tower", "image_proj_norm",
"image_pos_embed", "visual_temporal_embed"
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader.load_weights(weights)
return loader.load_weights(weights)
......@@ -16,7 +16,8 @@
""" PyTorch Fuyu model."""
import math
from array import array
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
import torch
import torch.nn as nn
......@@ -354,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
next_tokens = self.language_model.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
return loader.load_weights(weights)
......@@ -424,7 +424,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -469,3 +470,4 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
return loaded_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment