Unverified Commit 18ca59e1 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix arg count for partial functions (#12609)

parent 0cc2dc24
...@@ -119,6 +119,7 @@ from .trainer_utils import ( ...@@ -119,6 +119,7 @@ from .trainer_utils import (
default_hp_space, default_hp_space,
denumpify_detensorize, denumpify_detensorize,
get_last_checkpoint, get_last_checkpoint,
number_of_arguments,
set_seed, set_seed,
speed_metrics, speed_metrics,
) )
...@@ -905,7 +906,7 @@ class Trainer: ...@@ -905,7 +906,7 @@ class Trainer:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
def call_model_init(self, trial=None): def call_model_init(self, trial=None):
model_init_argcount = len(inspect.signature(self.model_init).parameters) model_init_argcount = number_of_arguments(self.model_init)
if model_init_argcount == 0: if model_init_argcount == 0:
model = self.model_init() model = self.model_init()
elif model_init_argcount == 1: elif model_init_argcount == 1:
......
...@@ -17,6 +17,7 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc ...@@ -17,6 +17,7 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc
""" """
import copy import copy
import functools
import gc import gc
import inspect import inspect
import os import os
...@@ -468,6 +469,16 @@ def denumpify_detensorize(metrics): ...@@ -468,6 +469,16 @@ def denumpify_detensorize(metrics):
return metrics return metrics
def number_of_arguments(func):
"""
Return the number of arguments of the passed function, even if it's a partial function.
"""
if isinstance(func, functools.partial):
total_args = len(inspect.signature(func.func).parameters)
return total_args - len(func.args) - len(func.keywords)
return len(inspect.signature(func).parameters)
class ShardedDDPOption(ExplicitEnum): class ShardedDDPOption(ExplicitEnum):
SIMPLE = "simple" SIMPLE = "simple"
ZERO_DP_2 = "zero_dp_2" ZERO_DP_2 = "zero_dp_2"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment