"...static/style/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "cbb63c5bec618354a25583c0861f45d4a01d9812"
Unverified Commit 44c5621d authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix tests (#11615)

parent 7eee950a
...@@ -1249,6 +1249,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1249,6 +1249,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
remove_prefix = not has_prefix_module and expects_prefix_module remove_prefix = not has_prefix_module and expects_prefix_module
add_prefix = has_prefix_module and not expects_prefix_module add_prefix = has_prefix_module and not expects_prefix_module
...@@ -1347,13 +1350,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1347,13 +1350,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
module_keys = set([".".join(key.split(".")[:-1]) for key in names]) module_keys = set([".".join(key.split(".")[:-1]) for key in names])
# torch.nn.ParameterList is a special case where two parameter keywords
# are appended to the module name, *e.g.* bert.special_embeddings.0
module_keys = module_keys.union(set([".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit()]))
retrieved_modules = [] retrieved_modules = []
# retrieve all modules that has at least one missing weight name # retrieve all modules that has at least one missing weight name
for name, module in self.named_modules(): for name, module in self.named_modules():
if remove_prefix: if remove_prefix:
name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name
elif add_prefix: elif add_prefix:
name = ".".join([self.base_model_prefix, name]) name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
if name in module_keys: if name in module_keys:
retrieved_modules.append(module) retrieved_modules.append(module)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment