Unverified Commit ae79b121 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

modify `WandbLogger` to accept arbitrary kwargs (#1491)

* make `WandbLogger` init args optional

* nit

* nit

* nit

* move import warning to `WandbLogger`

* nit

* update docs

* nit
parent 27a3da96
......@@ -321,7 +321,7 @@ lm_eval \
--log_samples
```
In the stdout, you will find the link to the W&B run page as well as a link to the generated report. You can find an example of this workflow in [examples/visualize-wandb.ipynb](examples/visualize-wandb.ipynb).
In the stdout, you will find the link to the W&B run page as well as a link to the generated report. You can find an example of this workflow in [examples/visualize-wandb.ipynb](examples/visualize-wandb.ipynb), which also shows how to integrate W&B logging beyond the CLI.
## How to Contribute or Learn More?
......
......@@ -67,6 +67,7 @@
"outputs": [],
"source": [
"import wandb\n",
"\n",
"wandb.login()"
]
},
......@@ -104,6 +105,43 @@
" --wandb_args project=lm-eval-harness-integration \\\n",
" --log_samples"
]
},
{
"cell_type": "markdown",
"id": "e974cabdbe70b667",
"metadata": {},
"source": ""
},
{
"cell_type": "markdown",
"id": "5178ca9445b844e4",
"metadata": {},
"source": "W&B can also be initialized programmatically for use outside the CLI to parse and log the results."
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6a421b2cf3ddac5",
"metadata": {},
"outputs": [],
"source": [
"import lm_eval\n",
"from lm_eval.logging_utils import WandbLogger\n",
"\n",
"results = lm_eval.simple_evaluate(\n",
" model=\"hf\",\n",
" model_args=\"pretrained=microsoft/phi-2,trust_remote_code=True\",\n",
" tasks=\"hellaswag,mmlu_abstract_algebra\",\n",
" log_samples=True,\n",
")\n",
"\n",
"wandb_logger = WandbLogger(\n",
" project=\"lm-eval-harness-integration\", job_type=\"eval\"\n",
") # or empty if wandb.init(...) already called before\n",
"wandb_logger.post_init(results)\n",
"wandb_logger.log_eval_result()\n",
"wandb_logger.log_eval_samples(results[\"samples\"]) # if log_samples"
]
}
],
"metadata": {
......
......@@ -14,7 +14,7 @@ from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging_utils import WandbLogger
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table
from lm_eval.utils import make_table, simple_parse_args_string
def _handle_non_serializable(o):
......@@ -210,7 +210,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
args = parse_eval_args()
if args.wandb_args:
wandb_logger = WandbLogger(args)
wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))
eval_logger = utils.eval_logger
eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
......
......@@ -13,24 +13,9 @@ from packaging.version import Version
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version
from lm_eval.utils import simple_parse_args_string
logger = logging.getLogger(__name__)
try:
import wandb
assert Version(wandb.__version__) >= Version("0.13.6")
if Version(wandb.__version__) < Version("0.13.6"):
wandb.require("report-editing:v0")
except Exception as e:
logger.warning(
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
f"{e}"
)
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
"""Remove the ',none' substring from the input_string if it exists at the end.
......@@ -83,14 +68,31 @@ def get_wandb_printer() -> Literal["Printer"]:
class WandbLogger:
def __init__(self, args: Any) -> None:
"""Initialize the WandbLogger.
def __init__(self, **kwargs) -> None:
"""Attaches to wandb logger if already initialized. Otherwise, passes kwargs to wandb.init()
Args:
results (Dict[str, Any]): The results dictionary.
args (Any): Arguments for configuration.
kwargs Optional[Any]: Arguments for configuration.
Parse and log the results returned from evaluator.simple_evaluate() with:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
wandb_logger.log_eval_samples(results["samples"])
"""
self.wandb_args: Dict[str, Any] = simple_parse_args_string(args.wandb_args)
try:
import wandb
assert Version(wandb.__version__) >= Version("0.13.6")
if Version(wandb.__version__) < Version("0.13.6"):
wandb.require("report-editing:v0")
except Exception as e:
logger.warning(
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
f"{e}"
)
self.wandb_args: Dict[str, Any] = kwargs
# initialize a W&B run
if wandb.run is None:
......@@ -164,6 +166,8 @@ class WandbLogger:
]
def make_table(columns: List[str], key: str = "results"):
import wandb
table = wandb.Table(columns=columns)
results = copy.deepcopy(self.results)
......@@ -202,6 +206,8 @@ class WandbLogger:
def _log_results_as_artifact(self) -> None:
"""Log results as JSON artifact to W&B."""
import wandb
dumped = json.dumps(
self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
)
......@@ -320,6 +326,8 @@ class WandbLogger:
def _log_samples_as_artifact(
self, data: List[Dict[str, Any]], task_name: str
) -> None:
import wandb
# log the samples as an artifact
dumped = json.dumps(
data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment