Unverified Commit b79028f0 authored by atturaioe's avatar atturaioe Committed by GitHub
Browse files

Fix TrainingArgs argument serialization (#19239)

parent 902d30b3
...@@ -17,7 +17,7 @@ import json ...@@ -17,7 +17,7 @@ import json
import math import math
import os import os
import warnings import warnings
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field, fields
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
...@@ -1000,10 +1000,6 @@ class TrainingArguments: ...@@ -1000,10 +1000,6 @@ class TrainingArguments:
if env_local_rank != -1 and env_local_rank != self.local_rank: if env_local_rank != -1 and env_local_rank != self.local_rank:
self.local_rank = env_local_rank self.local_rank = env_local_rank
# convert to int
self.log_level = trainer_log_levels[self.log_level]
self.log_level_replica = trainer_log_levels[self.log_level_replica]
# expand paths, if not os.makedirs("~/bar") will make directory # expand paths, if not os.makedirs("~/bar") will make directory
# in the current directory instead of the actual home # in the current directory instead of the actual home
#  see https://github.com/huggingface/transformers/issues/10628 #  see https://github.com/huggingface/transformers/issues/10628
...@@ -1604,8 +1600,12 @@ class TrainingArguments: ...@@ -1604,8 +1600,12 @@ class TrainingArguments:
The choice between the main and replica process settings is made according to the return value of `should_log`. The choice between the main and replica process settings is made according to the return value of `should_log`.
""" """
log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level # convert to int
log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica log_level = trainer_log_levels[self.log_level]
log_level_replica = trainer_log_levels[self.log_level_replica]
log_level_main_node = logging.INFO if log_level == -1 else log_level
log_level_replica_node = logging.WARNING if log_level_replica == -1 else log_level_replica
return log_level_main_node if self.should_log else log_level_replica_node return log_level_main_node if self.should_log else log_level_replica_node
@property @property
...@@ -1691,7 +1691,9 @@ class TrainingArguments: ...@@ -1691,7 +1691,9 @@ class TrainingArguments:
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
the token values by removing their value. the token values by removing their value.
""" """
d = asdict(self) # filter out fields that are defined as field(init=False)
d = dict((field.name, getattr(self, field.name)) for field in fields(self) if field.init)
for k, v in d.items(): for k, v in d.items():
if isinstance(v, Enum): if isinstance(v, Enum):
d[k] = v.value d[k] = v.value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment