Unverified commit f1e9bbaf, authored by JiLi and committed by GitHub
Browse files

feat: Add flexible validation for partial weight updates (#9663)


Co-authored-by: RichardW <rich-junwang@users.noreply.github.com>
Co-authored-by: Zhuorany <yzr1914001753@gmail.com>
Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: Night <32424487+PrinsYin@users.noreply.github.com>
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
parent 3fd1431d
...@@ -1029,10 +1029,6 @@ class GptOssForCausalLM(nn.Module): ...@@ -1029,10 +1029,6 @@ class GptOssForCausalLM(nn.Module):
) )
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
params_checker = {k: False for k, v in params_dict.items()}
for other_loaded_param_name in other_loaded_param_names:
params_checker[other_loaded_param_name] = True
for name, loaded_weight in weights: for name, loaded_weight in weights:
loaded_weight = _WeightCreator.maybe_materialize(loaded_weight) loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
...@@ -1069,7 +1065,6 @@ class GptOssForCausalLM(nn.Module): ...@@ -1069,7 +1065,6 @@ class GptOssForCausalLM(nn.Module):
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
params_checker[name] = True
break break
else: else:
for mapping in expert_params_mapping: for mapping in expert_params_mapping:
...@@ -1092,7 +1087,6 @@ class GptOssForCausalLM(nn.Module): ...@@ -1092,7 +1087,6 @@ class GptOssForCausalLM(nn.Module):
name, name,
shard_id=shard_id, shard_id=shard_id,
) )
params_checker[name] = True
break break
else: else:
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
...@@ -1111,17 +1105,9 @@ class GptOssForCausalLM(nn.Module): ...@@ -1111,17 +1105,9 @@ class GptOssForCausalLM(nn.Module):
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
params_checker[name] = True
else: else:
logger.warning(f"Parameter {name} not found in params_dict") logger.warning(f"Parameter {name} not found in params_dict")
not_loaded_params = [k for k, v in params_checker.items() if not v]
if tp_rank == 0:
if len(not_loaded_params) > 0:
raise Exception(f"Not all parameters loaded: {not_loaded_params}")
else:
logging.info("All parameters loaded successfully.")
def get_embed_and_head(self): def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight return self.model.embed_tokens.weight, self.lm_head.weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment