Unverified Commit 0e96cd18 authored by Debjyoti Ray's avatar Debjyoti Ray Committed by GitHub
Browse files

Fixed #3005: Processes both formats of model_args: string and dictionary (#3097)



* git push --force
correctly processes both formats of model_args: string and dictionary

* extract to function for better testing

* nit

---------
Co-authored-by: default avatarBaber <baber@hey.com>
parent 6e91fdcd
...@@ -4,6 +4,7 @@ import logging ...@@ -4,6 +4,7 @@ import logging
import os import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Union
import pandas as pd import pandas as pd
from zeno_client import ZenoClient, ZenoMetric from zeno_client import ZenoClient, ZenoMetric
...@@ -35,6 +36,22 @@ def parse_args(): ...@@ -35,6 +36,22 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def sanitize_string(model_args_raw: Union[str, dict]) -> str:
    """Return the model_args value as text with special characters replaced.

    A dictionary is first serialized with ``json.dumps``; any other value is
    used as given. Every run of one or more of the following characters is
    then collapsed into a double underscore: double quote, angle brackets,
    colon, forward slash, pipe, backslash, question mark, asterisk and
    square brackets.
    """
    if isinstance(model_args_raw, dict):
        text = json.dumps(model_args_raw)
    else:
        text = model_args_raw

    # One or more consecutive special characters become a single "__".
    special_run = re.compile(r"[\"<>:/|\\?*\[\]]+")
    return special_run.sub("__", text)
def main(): def main():
"""Upload the results of your benchmark tasks to the Zeno AI evaluation platform. """Upload the results of your benchmark tasks to the Zeno AI evaluation platform.
...@@ -87,13 +104,16 @@ def main(): ...@@ -87,13 +104,16 @@ def main():
latest_sample_results = get_latest_filename( latest_sample_results = get_latest_filename(
[Path(f).name for f in model_sample_filenames if task in f] [Path(f).name for f in model_sample_filenames if task in f]
) )
model_args = re.sub( # Load the model_args, which can be either a string or a dictionary
r"[\"<>:/\|\\?\*\[\]]+", model_args = sanitize_string(
"__",
json.load( json.load(
open(Path(args.data_path, model, latest_results), encoding="utf-8") open(
)["config"]["model_args"], Path(args.data_path, model, latest_results),
encoding="utf-8",
)
)["config"]["model_args"]
) )
print(model_args) print(model_args)
data = [] data = []
with open( with open(
......
import json
import re
import pytest
from scripts.zeno_visualize import sanitize_string
# NOTE: ``pytest.skip(...)`` is a function that raises as soon as it is called,
# so it cannot be used as a decorator; ``pytest.mark.skip`` is the marker that
# skips a single test while keeping the module importable.
@pytest.mark.skip(reason="requires zeno_client dependency")
def test_zeno_sanitize_string():
    """
    Test that sanitize_string from zeno_visualize.py properly handles
    different model_args formats (string and dictionary).
    """
    # Test case 1: model_args as a string
    string_model_args = "pretrained=EleutherAI/pythia-160m,dtype=float32"
    result_string = sanitize_string(string_model_args)
    expected_string = re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", string_model_args)

    # Test case 2: model_args as a dictionary (serialized with json.dumps
    # before the same substitution is applied)
    dict_model_args = {"pretrained": "EleutherAI/pythia-160m", "dtype": "float32"}
    result_dict = sanitize_string(dict_model_args)
    expected_dict = re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", json.dumps(dict_model_args))

    # Verify the results against the reference substitution
    assert result_string == expected_string
    assert result_dict == expected_dict

    # Pin the literal outputs too, so the test does not depend solely on
    # re-running the same regex as the code under test
    assert result_string == "pretrained=EleutherAI__pythia-160m,dtype=float32"
    assert (
        result_dict
        == "{__pretrained__ __EleutherAI__pythia-160m__, __dtype__ __float32__}"
    )

    # Also test that the sanitization works as expected
    assert ":" not in result_string  # No colons in sanitized output
    assert ":" not in result_dict  # No colons in sanitized output
    assert "/" not in result_dict  # No slashes in sanitized output
    assert "<" not in result_dict  # No angle brackets in sanitized output
# Allow running this file directly as a script, without the pytest runner:
# call the test function and report success if no assertion failed.
if __name__ == "__main__":
    test_zeno_sanitize_string()
    print("All tests passed.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment