Unverified Commit 7f0027db authored by Teven, committed by GitHub

Fixing bug with param count without embeddings (#12461)



* fixing bug with param count without embeddings

* Update src/transformers/modeling_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d5b8fe3b
@@ -352,11 +352,16 @@ class ModuleUtilsMixin:
             :obj:`int`: The number of parameters.
         """
-        def parameter_filter(x):
-            return (x.requires_grad or not only_trainable) and not (isinstance(x, nn.Embedding) and exclude_embeddings)
-        params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
-        return sum(p.numel() for p in params)
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
+            ]
+            non_embedding_parameters = [
+                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+            ]
+            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+        else:
+            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
 
     def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
         """
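For context on the bug: the removed parameter_filter was applied to self.parameters(), i.e. to nn.Parameter objects, so isinstance(x, nn.Embedding) was never true and exclude_embeddings had no effect; the filter was also skipped entirely when only_trainable was False. The patched version instead collects the names of embedding weights via named_modules() and drops them from named_parameters(). Below is a minimal sketch of the corrected counting behavior on a toy module; the ToyModel class and the standalone num_parameters helper are illustrative stand-ins, not part of the transformers library.

from torch import nn


class ToyModel(nn.Module):
    # Illustrative model with one embedding layer and one non-embedding layer.
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)  # 100 * 16 = 1600 parameters
        self.proj = nn.Linear(16, 4)        # 16 * 4 + 4 = 68 parameters


def num_parameters(model, only_trainable=False, exclude_embeddings=False):
    # Standalone re-implementation of the patched counting logic, for illustration only.
    if exclude_embeddings:
        # Fully qualified names of every embedding weight in the module tree.
        embedding_param_names = [
            f"{name}.weight"
            for name, module in model.named_modules()
            if isinstance(module, nn.Embedding)
        ]
        params = [
            parameter
            for name, parameter in model.named_parameters()
            if name not in embedding_param_names
        ]
    else:
        params = list(model.parameters())
    return sum(p.numel() for p in params if p.requires_grad or not only_trainable)


model = ToyModel()
print(num_parameters(model))                           # 1668
print(num_parameters(model, exclude_embeddings=True))  # 68, embedding weights excluded

With the old filter, the second call would still have returned 1668, since no element of model.parameters() is an nn.Embedding instance.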