Unverified Commit 604b62c4 authored by Surya Kasturi's avatar Surya Kasturi Committed by GitHub
Browse files

Allow writing config to wandb (#2736)

* Allow writing confing to wandb

* set defaults

* Update help

* Update help
parent b8adf3cc
...@@ -240,6 +240,12 @@ def setup_parser() -> argparse.ArgumentParser: ...@@ -240,6 +240,12 @@ def setup_parser() -> argparse.ArgumentParser:
default="", default="",
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval", help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
) )
parser.add_argument(
"--wandb_config_args",
type=str,
default="",
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
)
parser.add_argument( parser.add_argument(
"--hf_hub_log_args", "--hf_hub_log_args",
type=str, type=str,
...@@ -300,7 +306,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -300,7 +306,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
args = parse_eval_args(parser) args = parse_eval_args(parser)
if args.wandb_args: if args.wandb_args:
wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args)) wandb_args_dict = simple_parse_args_string(args.wandb_args)
wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
utils.setup_logging(args.verbosity) utils.setup_logging(args.verbosity)
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
......
...@@ -22,11 +22,12 @@ def get_wandb_printer() -> Literal["Printer"]: ...@@ -22,11 +22,12 @@ def get_wandb_printer() -> Literal["Printer"]:
class WandbLogger: class WandbLogger:
def __init__(self, **kwargs) -> None: def __init__(self, init_args=None, config_args=None) -> None:
"""Attaches to wandb logger if already initialized. Otherwise, passes kwargs to wandb.init() """Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update()
Args: Args:
kwargs Optional[Any]: Arguments for configuration. init_args Optional[Dict]: Arguments for init configuration.
config_args Optional[Dict]: Arguments for config
Parse and log the results returned from evaluator.simple_evaluate() with: Parse and log the results returned from evaluator.simple_evaluate() with:
wandb_logger.post_init(results) wandb_logger.post_init(results)
...@@ -46,7 +47,8 @@ class WandbLogger: ...@@ -46,7 +47,8 @@ class WandbLogger:
f"{e}" f"{e}"
) )
self.wandb_args: Dict[str, Any] = kwargs self.wandb_args: Dict[str, Any] = init_args or {}
self.wandb_config_args: Dict[str, Any] = config_args or {}
# pop the step key from the args to save for all logging calls # pop the step key from the args to save for all logging calls
self.step = self.wandb_args.pop("step", None) self.step = self.wandb_args.pop("step", None)
...@@ -54,6 +56,8 @@ class WandbLogger: ...@@ -54,6 +56,8 @@ class WandbLogger:
# initialize a W&B run # initialize a W&B run
if wandb.run is None: if wandb.run is None:
self.run = wandb.init(**self.wandb_args) self.run = wandb.init(**self.wandb_args)
if self.wandb_config_args:
self.run.config.update(self.wandb_config_args)
else: else:
self.run = wandb.run self.run = wandb.run
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment