Unverified Commit 60861fe1 authored by 조준래, committed by GitHub

Implement JSON dump conversion for torch_dtype in TrainingArguments (#31224)



* Implement JSON dump conversion for torch_dtype in TrainingArguments

* Add unit test for converting torch_dtype in TrainingArguments to JSON

* move unit test for converting torch_dtype into TrainerIntegrationTest class

* reformatting using ruff

* convert dict_torch_dtype_to_str to private method _dict_torch_dtype_to_str

---------
Co-authored-by: jun.4 <jun.4@kakaobrain.com>
parent ff689f57
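For context, a minimal standalone sketch (not part of the commit) of the problem this change addresses: `torch.dtype` values are not JSON-serializable, so a `TrainingArguments` field holding one breaks JSON serialization unless the dtype is first converted to its short string name.

import json

import torch

config = {"torch_dtype": torch.float32}

try:
    json.dumps(config)
except TypeError as err:
    # json.dumps cannot handle torch.dtype objects and raises a TypeError
    print(f"serialization failed: {err}")

# The fix stores only the short type name, e.g. torch.float32 -> "float32"
config["torch_dtype"] = str(config["torch_dtype"]).split(".")[1]
print(json.dumps(config))  # {"torch_dtype": "float32"}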
@@ -2370,6 +2370,18 @@ class TrainingArguments:
        )
        return warmup_steps
    def _dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """
        Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and, if it's not None,
        converts torch.dtype to a string of just the type. For example, `torch.float32` gets converted into the
        *"float32"* string, which can then be stored in the JSON format.
        """
        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
        for value in d.values():
            if isinstance(value, dict):
                self._dict_torch_dtype_to_str(value)

    def to_dict(self):
        """
        Serializes this instance while replacing `Enum` by their values (for JSON serialization support). It obfuscates
@@ -2388,6 +2400,8 @@ class TrainingArguments:
            # Handle the accelerator_config if passed
            if is_accelerate_available() and isinstance(v, AcceleratorConfig):
                d[k] = v.to_dict()
        self._dict_torch_dtype_to_str(d)
        return d

    def to_json_string(self):
...
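A hedged standalone sketch of what `_dict_torch_dtype_to_str` does, including the recursion into nested dicts; the `sub_config` key below is only an illustration, not a real `TrainingArguments` field.

from typing import Any, Dict

import torch


def dict_torch_dtype_to_str(d: Dict[str, Any]) -> None:
    # Same logic as the new private helper: rewrite a non-string torch_dtype
    # value to its short name, then recurse into any nested dictionaries.
    if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
        d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
    for value in d.values():
        if isinstance(value, dict):
            dict_torch_dtype_to_str(value)


nested = {"torch_dtype": torch.float16, "sub_config": {"torch_dtype": torch.int8}}
dict_torch_dtype_to_str(nested)
print(nested)  # {'torch_dtype': 'float16', 'sub_config': {'torch_dtype': 'int8'}}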
@@ -3445,6 +3445,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
            )
        self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))

    def test_torch_dtype_to_json(self):
        @dataclasses.dataclass
        class TorchDtypeTrainingArguments(TrainingArguments):
            torch_dtype: torch.dtype = dataclasses.field(
                default=torch.float32,
            )

        for dtype in [
            "float32",
            "float64",
            "complex64",
            "complex128",
            "float16",
            "bfloat16",
            "uint8",
            "int8",
            "int16",
            "int32",
            "int64",
            "bool",
        ]:
            torch_dtype = getattr(torch, dtype)
            with tempfile.TemporaryDirectory() as tmp_dir:
                args = TorchDtypeTrainingArguments(output_dir=tmp_dir, torch_dtype=torch_dtype)
                args_dict = args.to_dict()
                self.assertIn("torch_dtype", args_dict)
                self.assertEqual(args_dict["torch_dtype"], dtype)

@require_torch
@is_staging_test
...
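A hedged end-to-end sketch of the behavior the test above verifies, extended to the JSON string path; the `TorchDtypeTrainingArguments` subclass mirrors the one in the test and is only a vehicle for holding a `torch.dtype` field, and the output assumes this commit is applied.

import dataclasses
import json
import tempfile

import torch
from transformers import TrainingArguments


@dataclasses.dataclass
class TorchDtypeTrainingArguments(TrainingArguments):
    torch_dtype: torch.dtype = dataclasses.field(default=torch.float32)


with tempfile.TemporaryDirectory() as tmp_dir:
    args = TorchDtypeTrainingArguments(output_dir=tmp_dir, torch_dtype=torch.bfloat16)
    # to_dict() now routes through _dict_torch_dtype_to_str, so the dtype comes
    # back as the plain string "bfloat16" ...
    print(args.to_dict()["torch_dtype"])
    # ... and to_json_string() produces valid JSON instead of raising a TypeError.
    print(json.loads(args.to_json_string())["torch_dtype"])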