Unverified Commit 476844d4 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix underscores in dict keys passed via CLI (#19030)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 4e68ae5e
...@@ -259,11 +259,18 @@ def test_dict_args(parser): ...@@ -259,11 +259,18 @@ def test_dict_args(parser):
"--model-name=something.something", "--model-name=something.something",
"--hf-overrides.key1", "--hf-overrides.key1",
"val1", "val1",
# Test nesting
"--hf-overrides.key2.key3", "--hf-overrides.key2.key3",
"val2", "val2",
"--hf-overrides.key2.key4", "--hf-overrides.key2.key4",
"val3", "val3",
# Test = sign
"--hf-overrides.key5=val4", "--hf-overrides.key5=val4",
# Test underscore to dash conversion
"--hf_overrides.key_6",
"val5",
"--hf_overrides.key-7.key_8",
"val6",
] ]
parsed_args = parser.parse_args(args) parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something" assert parsed_args.model_name == "something.something"
...@@ -274,6 +281,10 @@ def test_dict_args(parser): ...@@ -274,6 +281,10 @@ def test_dict_args(parser):
"key4": "val3", "key4": "val3",
}, },
"key5": "val4", "key5": "val4",
"key_6": "val5",
"key-7": {
"key_8": "val6",
},
} }
......
...@@ -1456,17 +1456,24 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -1456,17 +1456,24 @@ class FlexibleArgumentParser(ArgumentParser):
if '--config' in args: if '--config' in args:
args = self._pull_args_from_config(args) args = self._pull_args_from_config(args)
def repl(match: re.Match) -> str:
"""Replaces underscores with dashes in the matched string."""
return match.group(0).replace("_", "-")
# Everything between the first -- and the first .
pattern = re.compile(r"(?<=--)[^\.]*")
# Convert underscores to dashes and vice versa in argument names # Convert underscores to dashes and vice versa in argument names
processed_args = [] processed_args = []
for arg in args: for arg in args:
if arg.startswith('--'): if arg.startswith('--'):
if '=' in arg: if '=' in arg:
key, value = arg.split('=', 1) key, value = arg.split('=', 1)
key = '--' + key[len('--'):].replace('_', '-') key = pattern.sub(repl, key, count=1)
processed_args.append(f'{key}={value}') processed_args.append(f'{key}={value}')
else: else:
processed_args.append('--' + key = pattern.sub(repl, arg, count=1)
arg[len('--'):].replace('_', '-')) processed_args.append(key)
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
# allow -O flag to be used without space, e.g. -O3 # allow -O flag to be used without space, e.g. -O3
processed_args.append('-O') processed_args.append('-O')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment