Unverified commit 7f0027db authored by Teven, committed by GitHub

Fixing bug with param count without embeddings (#12461)



* fixing bug with param count without embeddings

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d5b8fe3b
src/transformers/modeling_utils.py
@@ -352,11 +352,16 @@ class ModuleUtilsMixin:
             :obj:`int`: The number of parameters.
         """
-        def parameter_filter(x):
-            return (x.requires_grad or not only_trainable) and not (isinstance(x, nn.Embedding) and exclude_embeddings)
-
-        params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
-        return sum(p.numel() for p in params)
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
+            ]
+            non_embedding_parameters = [
+                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+            ]
+            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+        else:
+            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
 
     def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
         """
         ...
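The bug being fixed is visible in the removed lines: parameter_filter tested isinstance(x, nn.Embedding) on the values yielded by self.parameters(), which are nn.Parameter objects and never nn.Embedding modules, so the exclude_embeddings check could never match; the filter was also only applied at all when only_trainable was True. The replacement collects the names of embedding weights from named_modules() and filters named_parameters() against that list. A minimal usage sketch of the fixed behavior, assuming a standard transformers install (the model name is illustrative, not part of the commit):

from transformers import AutoModel

model = AutoModel.from_pretrained("gpt2")

# All parameters, including the token/position embedding matrices.
total = model.num_parameters()

# With this fix, exclude_embeddings=True actually drops nn.Embedding
# weights from the count; previously the flag was silently ignored.
non_embedding = model.num_parameters(exclude_embeddings=True)

print(f"total={total:,} non-embedding={non_embedding:,}")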