Unverified Commit 18ca59e1 authored by Sylvain Gugger, committed by GitHub

Fix arg count for partial functions (#12609)

parent 0cc2dc24
src/transformers/trainer.py
@@ -119,6 +119,7 @@ from .trainer_utils import (
     default_hp_space,
     denumpify_detensorize,
     get_last_checkpoint,
+    number_of_arguments,
     set_seed,
     speed_metrics,
 )
@@ -905,7 +906,7 @@ class Trainer:
             torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

     def call_model_init(self, trial=None):
-        model_init_argcount = len(inspect.signature(self.model_init).parameters)
+        model_init_argcount = number_of_arguments(self.model_init)
         if model_init_argcount == 0:
             model = self.model_init()
         elif model_init_argcount == 1:
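
For context, here is a minimal sketch of the failure mode this trainer.py change addresses (build_model and dropout are hypothetical stand-ins for a user-supplied model_init, not names from the commit): when model_init is a functools.partial with a keyword argument already bound, inspect.signature still reports that argument, so the old count told call_model_init to pass trial to a callable that could no longer accept it.

import functools
import inspect

# Hypothetical user callable; the extra keyword is pre-bound with
# functools.partial so the resulting callable needs no arguments at all.
def build_model(dropout):
    return {"dropout": dropout}

model_init = functools.partial(build_model, dropout=0.1)

# Old check: a keyword argument bound by the partial still appears in the
# signature (as a keyword-only parameter with a default), so the count is 1 ...
assert len(inspect.signature(model_init).parameters) == 1

# ... which made call_model_init invoke model_init(trial) and fail, because
# dropout would then receive two values.
try:
    model_init("trial-object")
except TypeError as exc:
    print(exc)  # build_model() got multiple values for argument 'dropout'

# With the fix, the arguments already bound by the partial are subtracted and
# the count is 0, so the Trainer calls model_init() with no arguments.
total_args = len(inspect.signature(model_init.func).parameters)
assert total_args - len(model_init.args) - len(model_init.keywords) == 0
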
src/transformers/trainer_utils.py
@@ -17,6 +17,7 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc
 """
 import copy
+import functools
 import gc
 import inspect
 import os
@@ -468,6 +469,16 @@ def denumpify_detensorize(metrics):
     return metrics


+def number_of_arguments(func):
+    """
+    Return the number of arguments of the passed function, even if it's a partial function.
+    """
+    if isinstance(func, functools.partial):
+        total_args = len(inspect.signature(func.func).parameters)
+        return total_args - len(func.args) - len(func.keywords)
+    return len(inspect.signature(func).parameters)
+
+
 class ShardedDDPOption(ExplicitEnum):
     SIMPLE = "simple"
     ZERO_DP_2 = "zero_dp_2"
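
A quick usage sketch of the new helper. The function body is copied from the hunk above and redefined locally so the example runs without transformers installed; model_init and its parameters are illustrative names, not part of the commit.

import functools
import inspect

def number_of_arguments(func):
    # Same logic as the helper added in trainer_utils.py above.
    if isinstance(func, functools.partial):
        total_args = len(inspect.signature(func.func).parameters)
        return total_args - len(func.args) - len(func.keywords)
    return len(inspect.signature(func).parameters)

def model_init(trial, learning_rate):
    ...

# Plain function: the naive signature count and the helper agree.
assert number_of_arguments(model_init) == 2

# Positionally bound argument: inspect.signature already drops it,
# so both counts happen to match.
p_pos = functools.partial(model_init, "trial")
assert len(inspect.signature(p_pos).parameters) == 1
assert number_of_arguments(p_pos) == 1

# Keyword-bound argument: it stays in the signature as a default, so the
# naive count is still 2 even though only one argument must be supplied.
p_kw = functools.partial(model_init, learning_rate=5e-5)
assert len(inspect.signature(p_kw).parameters) == 2
assert number_of_arguments(p_kw) == 1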