"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "87927b248e9154b74c58dc9aef1438b3eb327937"
Unverified Commit b79028f0 authored by atturaioe's avatar atturaioe Committed by GitHub
Browse files

Fix TrainingArgs argument serialization (#19239)

parent 902d30b3
......@@ -17,7 +17,7 @@ import json
import math
import os
import warnings
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, fields
from datetime import timedelta
from enum import Enum
from pathlib import Path
......@@ -1000,10 +1000,6 @@ class TrainingArguments:
if env_local_rank != -1 and env_local_rank != self.local_rank:
self.local_rank = env_local_rank
# convert to int
self.log_level = trainer_log_levels[self.log_level]
self.log_level_replica = trainer_log_levels[self.log_level_replica]
# expand paths, if not os.makedirs("~/bar") will make directory
# in the current directory instead of the actual home
#  see https://github.com/huggingface/transformers/issues/10628
......@@ -1604,8 +1600,12 @@ class TrainingArguments:
The choice between the main and replica process settings is made according to the return value of `should_log`.
"""
log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level
log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica
# convert to int
log_level = trainer_log_levels[self.log_level]
log_level_replica = trainer_log_levels[self.log_level_replica]
log_level_main_node = logging.INFO if log_level == -1 else log_level
log_level_replica_node = logging.WARNING if log_level_replica == -1 else log_level_replica
return log_level_main_node if self.should_log else log_level_replica_node
@property
......@@ -1691,7 +1691,9 @@ class TrainingArguments:
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
the token values by removing their value.
"""
d = asdict(self)
# filter out fields that are defined as field(init=False)
d = dict((field.name, getattr(self, field.name)) for field in fields(self) if field.init)
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment