Unverified Commit 2df60287 authored by IMvision12's avatar IMvision12 Committed by GitHub
Browse files

Added tests for yaml and json parser (#19219)

* Added tests for yaml and json

* Added tests for yaml and json
parent 2d956958
......@@ -281,7 +281,9 @@ class HfArgumentParser(ArgumentParser):
- the dataclass instances in the same order as they were passed to the initializer.
"""
outputs = self.parse_dict(json.loads(Path(json_file).read_text()), allow_extra_keys=allow_extra_keys)
open_json_file = open(Path(json_file))
data = json.loads(open_json_file.read())
outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
return tuple(outputs)
def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
......@@ -301,5 +303,5 @@ class HfArgumentParser(ArgumentParser):
- the dataclass instances in the same order as they were passed to the initializer.
"""
outputs = self.parse_dict(yaml.safe_load(yaml_file), allow_extra_keys=allow_extra_keys)
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
return tuple(outputs)
......@@ -13,12 +13,17 @@
# limitations under the License.
import argparse
import json
import os
import tempfile
import unittest
from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import List, Optional
import yaml
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool
......@@ -258,6 +263,43 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
def test_parse_json(self):
parser = HfArgumentParser(BasicExample)
args_dict_for_json = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
}
with tempfile.TemporaryDirectory() as tmp_dir:
temp_local_path = os.path.join(tmp_dir, "temp_json")
os.mkdir(temp_local_path)
with open(temp_local_path + ".json", "w+") as f:
json.dump(args_dict_for_json, f)
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".json"))[0]
args = BasicExample(**args_dict_for_json)
self.assertEqual(parsed_args, args)
def test_parse_yaml(self):
parser = HfArgumentParser(BasicExample)
args_dict_for_yaml = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
}
with tempfile.TemporaryDirectory() as tmp_dir:
temp_local_path = os.path.join(tmp_dir, "temp_yaml")
os.mkdir(temp_local_path)
with open(temp_local_path + ".yaml", "w+") as f:
yaml.dump(args_dict_for_yaml, f)
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".yaml"))[0]
args = BasicExample(**args_dict_for_yaml)
self.assertEqual(parsed_args, args)
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment