Unverified Commit 604b62c4 authored by Surya Kasturi's avatar Surya Kasturi Committed by GitHub
Browse files

Allow writing config to wandb (#2736)

* Allow writing confing to wandb

* set defaults

* Update help

* Update help
parent b8adf3cc
......@@ -240,6 +240,12 @@ def setup_parser() -> argparse.ArgumentParser:
default="",
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
)
parser.add_argument(
"--wandb_config_args",
type=str,
default="",
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
)
parser.add_argument(
"--hf_hub_log_args",
type=str,
......@@ -300,7 +306,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
args = parse_eval_args(parser)
if args.wandb_args:
wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))
wandb_args_dict = simple_parse_args_string(args.wandb_args)
wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
utils.setup_logging(args.verbosity)
eval_logger = logging.getLogger(__name__)
......
......@@ -22,11 +22,12 @@ def get_wandb_printer() -> Literal["Printer"]:
class WandbLogger:
def __init__(self, **kwargs) -> None:
"""Attaches to wandb logger if already initialized. Otherwise, passes kwargs to wandb.init()
def __init__(self, init_args=None, config_args=None) -> None:
"""Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update()
Args:
kwargs Optional[Any]: Arguments for configuration.
init_args Optional[Dict]: Arguments for init configuration.
config_args Optional[Dict]: Arguments for config
Parse and log the results returned from evaluator.simple_evaluate() with:
wandb_logger.post_init(results)
......@@ -46,7 +47,8 @@ class WandbLogger:
f"{e}"
)
self.wandb_args: Dict[str, Any] = kwargs
self.wandb_args: Dict[str, Any] = init_args or {}
self.wandb_config_args: Dict[str, Any] = config_args or {}
# pop the step key from the args to save for all logging calls
self.step = self.wandb_args.pop("step", None)
......@@ -54,6 +56,8 @@ class WandbLogger:
# initialize a W&B run
if wandb.run is None:
self.run = wandb.init(**self.wandb_args)
if self.wandb_config_args:
self.run.config.update(self.wandb_config_args)
else:
self.run = wandb.run
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment