Unverified Commit 5e8c8eb5 authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
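The changes below apply ruff's flake8-comprehensions rules (the C4xx family) across the examples and research projects. The most common rewrite is C408: a `dict(...)` call with keyword arguments becomes a dict literal, which skips the builtin lookup and function call and builds the identical dictionary. A minimal sketch of the pattern, using values from the first hunk below:

# C408: dict() call with keyword arguments vs. a literal.
# Both expressions build the same dictionary; the literal avoids the
# name lookup and call, and also permits keys that aren't identifiers.
via_call = dict(student_encoder_layers=2, no_teacher=True)
via_literal = {"student_encoder_layers": 2, "no_teacher": True}
assert via_call == via_literal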
@@ -145,18 +145,18 @@ class TestSummarizationDistiller(TestCasePlus):
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"

     def test_distill_no_teacher(self):
-        updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
+        updates = {"student_encoder_layers": 2, "student_decoder_layers": 1, "no_teacher": True}
         self._test_distiller_cli(updates)

     def test_distill_checkpointing_with_teacher(self):
-        updates = dict(
-            student_encoder_layers=2,
-            student_decoder_layers=1,
-            max_epochs=4,
-            val_check_interval=0.25,
-            alpha_hid=2.0,
-            model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
-        )
+        updates = {
+            "student_encoder_layers": 2,
+            "student_decoder_layers": 1,
+            "max_epochs": 4,
+            "val_check_interval": 0.25,
+            "alpha_hid": 2.0,
+            "model_name_or_path": "IGNORE_THIS_IT_DOESNT_GET_USED",
+        }
         model = self._test_distiller_cli(updates, check_contents=False)

         ckpts = list(Path(model.output_dir).glob("*.ckpt"))
@@ -193,19 +193,19 @@ class TestSummarizationDistiller(TestCasePlus):
         self.assertEqual(nll_loss, model_computed_loss)

     def test_distill_mbart(self):
-        updates = dict(
-            student_encoder_layers=2,
-            student_decoder_layers=1,
-            num_train_epochs=4,
-            val_check_interval=0.25,
-            alpha_hid=2.0,
-            task="translation",
-            model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
-            tokenizer_name=MBART_TINY,
-            teacher=MBART_TINY,
-            src_lang="en_XX",
-            tgt_lang="ro_RO",
-        )
+        updates = {
+            "student_encoder_layers": 2,
+            "student_decoder_layers": 1,
+            "num_train_epochs": 4,
+            "val_check_interval": 0.25,
+            "alpha_hid": 2.0,
+            "task": "translation",
+            "model_name_or_path": "IGNORE_THIS_IT_DOESNT_GET_USED",
+            "tokenizer_name": MBART_TINY,
+            "teacher": MBART_TINY,
+            "src_lang": "en_XX",
+            "tgt_lang": "ro_RO",
+        }
         model = self._test_distiller_cli(updates, check_contents=False)

         assert model.model.config.model_type == "mbart"
@@ -217,39 +217,39 @@ class TestSummarizationDistiller(TestCasePlus):
         self.assertEqual(len(transformer_ckpts), 2)

     def test_distill_t5(self):
-        updates = dict(
-            student_encoder_layers=1,
-            student_decoder_layers=1,
-            alpha_hid=2.0,
-            teacher=T5_TINY,
-            model_name_or_path=T5_TINY,
-            tokenizer_name=T5_TINY,
-        )
+        updates = {
+            "student_encoder_layers": 1,
+            "student_decoder_layers": 1,
+            "alpha_hid": 2.0,
+            "teacher": T5_TINY,
+            "model_name_or_path": T5_TINY,
+            "tokenizer_name": T5_TINY,
+        }
         self._test_distiller_cli(updates)

     def test_distill_different_base_models(self):
-        updates = dict(
-            teacher=T5_TINY,
-            student=T5_TINIER,
-            model_name_or_path=T5_TINIER,
-            tokenizer_name=T5_TINIER,
-        )
+        updates = {
+            "teacher": T5_TINY,
+            "student": T5_TINIER,
+            "model_name_or_path": T5_TINIER,
+            "tokenizer_name": T5_TINIER,
+        }
         self._test_distiller_cli(updates)

     def _test_distiller_cli(self, updates, check_contents=True):
-        default_updates = dict(
-            label_smoothing=0.0,
-            early_stopping_patience=-1,
-            train_batch_size=1,
-            eval_batch_size=2,
-            max_epochs=2,
-            alpha_mlm=0.2,
-            alpha_ce=0.8,
-            do_predict=True,
-            model_name_or_path="sshleifer/tinier_bart",
-            teacher=CHEAP_ARGS["model_name_or_path"],
-            val_check_interval=0.5,
-        )
+        default_updates = {
+            "label_smoothing": 0.0,
+            "early_stopping_patience": -1,
+            "train_batch_size": 1,
+            "eval_batch_size": 2,
+            "max_epochs": 2,
+            "alpha_mlm": 0.2,
+            "alpha_ce": 0.8,
+            "do_predict": True,
+            "model_name_or_path": "sshleifer/tinier_bart",
+            "teacher": CHEAP_ARGS["model_name_or_path"],
+            "val_check_interval": 0.5,
+        }
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
         tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
...
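An aside, not part of this commit: the merge-then-mutate pattern in `_test_distiller_cli` (`default_updates.update(updates)`) could also be written with dict unpacking, which leaves the defaults untouched. A small sketch with values trimmed from the hunk above:

default_updates = {"max_epochs": 2, "val_check_interval": 0.5}  # trimmed from the test defaults
updates = {"max_epochs": 4}
merged = {**default_updates, **updates}  # right-hand operand wins on key clashes
assert merged == {"max_epochs": 4, "val_check_interval": 0.5}
assert default_updates["max_epochs"] == 2  # the defaults are not mutated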
@@ -98,29 +98,29 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
     @require_torch_multi_gpu
     def test_multi_gpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            overwrite_output_dir=True,
-            sortish_sampler=True,
-        )
+        updates = {
+            "no_teacher": True,
+            "freeze_encoder": True,
+            "gpus": 2,
+            "overwrite_output_dir": True,
+            "sortish_sampler": True,
+        }
         self._test_distiller_cli_fork(updates, check_contents=False)

     def _test_distiller_cli_fork(self, updates, check_contents=True):
-        default_updates = dict(
-            label_smoothing=0.0,
-            early_stopping_patience=-1,
-            train_batch_size=1,
-            eval_batch_size=2,
-            max_epochs=2,
-            alpha_mlm=0.2,
-            alpha_ce=0.8,
-            do_predict=True,
-            model_name_or_path="sshleifer/tinier_bart",
-            teacher=CHEAP_ARGS["model_name_or_path"],
-            val_check_interval=0.5,
-        )
+        default_updates = {
+            "label_smoothing": 0.0,
+            "early_stopping_patience": -1,
+            "train_batch_size": 1,
+            "eval_batch_size": 2,
+            "max_epochs": 2,
+            "alpha_mlm": 0.2,
+            "alpha_ce": 0.8,
+            "do_predict": True,
+            "model_name_or_path": "sshleifer/tinier_bart",
+            "teacher": CHEAP_ARGS["model_name_or_path"],
+            "val_check_interval": 0.5,
+        }
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
         tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
...
@@ -74,11 +74,11 @@ class SummarizationModule(BaseTransformer):
         self.model_type = self.config.model_type
         self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=self.model.config.prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": self.model.config.prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,
@@ -433,7 +433,7 @@ def main(args, model=None) -> SummarizationModule:
         return model

     model.hparams.test_checkpoint = ""
-    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
+    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
     if checkpoints:
         model.hparams.test_checkpoint = checkpoints[-1]
         trainer.resume_from_checkpoint = checkpoints[-1]
...
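The `checkpoints` line above is rule C413: `sorted()` always returns a new list no matter what iterable it receives, so wrapping it in `list()` only adds a redundant copy. A quick sketch (the "output" directory is a placeholder, not from the diff):

import glob
import os

# sorted() returns a list even for generators, tuples, or sets,
# so list(sorted(...)) was a redundant extra copy (ruff rule C413).
checkpoints = sorted(glob.glob(os.path.join("output", "*.ckpt"), recursive=True))
assert isinstance(checkpoints, list)  # holds even when no files match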
@@ -171,11 +171,11 @@ def create_student_by_copying_alternating_layers(
     logger.info(
         f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
     )
-    student.config.init_metadata = dict(
-        teacher_type=teacher.config.model_type,
-        copied_encoder_layers=e_layers_to_copy,
-        copied_decoder_layers=d_layers_to_copy,
-    )
+    student.config.init_metadata = {
+        "teacher_type": teacher.config.model_type,
+        "copied_encoder_layers": e_layers_to_copy,
+        "copied_decoder_layers": d_layers_to_copy,
+    }
     student.save_pretrained(save_path)

     # Save information about copying for easier reproducibility
...
@@ -63,7 +63,7 @@ def generate_summaries_or_translations(
     fout.close()
     runtime = int(time.time() - start_time)  # seconds
     n_obs = len(examples)
-    return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
+    return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}


 def datetime_now():
...
@@ -437,7 +437,7 @@ def pickle_save(obj, path):
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
...
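`flatten_list` is rule C416: a comprehension of the form `[x for x in iterable]` is an identity pass, so `list(iterable)` says the same thing directly. The flattening itself is done by `itertools.chain.from_iterable`:

import itertools

# chain.from_iterable lazily concatenates the inner lists; list()
# materializes the result. The identity comprehension added nothing.
assert list(itertools.chain.from_iterable([[1, 2], [3]])) == [1, 2, 3]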
@@ -30,7 +30,7 @@ EMPTY_ANSWER_AGG = "none"
 def _split_thousands(delimiter, value):
     split = value.split(delimiter)
-    return len(split) > 1 and any(map(lambda x: len(x) == 3, split))
+    return len(split) > 1 and any((len(x) == 3 for x in split))


 def convert_to_float(value):
@@ -123,7 +123,7 @@ _TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE | re.MULTILINE | re.DOTALL)
 def _normalize_for_match(x):
-    return [t for t in _TOKENIZER.findall(x.lower())]
+    return list(_TOKENIZER.findall(x.lower()))


 def _compare(operator, src, tgt):
...
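`_split_thousands` is rule C417: `map(lambda x: ..., split)` pays a Python function call per element, while a generator expression inlines the comparison; `any()` short-circuits identically over both. Note the autofix leaves the generator wrapped in its own parentheses, which is harmless but unnecessary when it is the sole argument:

# Equivalent forms; the generator expression avoids one lambda call per element.
split = "1,234,567".split(",")
assert any(map(lambda x: len(x) == 3, split))  # before (flagged by C417)
assert any(len(x) == 3 for x in split)         # after, minus the redundant parens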
@@ -61,7 +61,7 @@ class Extract:
         assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
         if subset_list is not None:
             with open(os.path.realpath(subset_list)) as f:
-                self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
+                self.subset_list = {self._vqa_file_split()[0] for x in tryload(f)}
         else:
             self.subset_list = None
...
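The `subset_list` rewrite is the set-comprehension form of the same rule: `set(map(lambda x: f(x), items))` becomes `{f(x) for x in items}`. Worth noting: the original lambda never used its argument (it calls `self._vqa_file_split()` with no input), and the comprehension faithfully preserves that behavior rather than fixing it. The generic shape, with illustrative values not from the diff:

# set(map(lambda x: f(x), items)) -> {f(x) for x in items}
names = {s.strip() for s in ["a ", " b"]}
assert names == {"a", "b"}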
@@ -1095,7 +1095,7 @@ class ROIPooler(nn.Module):
         Returns:
             A tensor of shape(N*B, Channels, output_size, output_size)
         """
-        x = [v for v in feature_maps.values()]
+        x = list(feature_maps.values())
         num_level_assignments = len(self.level_poolers)

         assert len(x) == num_level_assignments and len(boxes) == x[0].size(0)
...
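The `ROIPooler` change is C416 applied to a dict view: `[v for v in d.values()]` copies the view element by element through a no-op comprehension, while `list(d.values())` does the same copy directly and preserves insertion order:

# Stand-in values; the real feature_maps dict holds one tensor per feature level.
feature_maps = {"p2": 1, "p3": 2}
assert list(feature_maps.values()) == [1, 2]  # insertion order is preserved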
@@ -99,7 +99,7 @@ class VQGAN_CLIP(nn.Module):
             output_path = "./animation.gif"
         if input_path is None:
             input_path = self.save_path
-        paths = list(sorted(glob(input_path + "/*")))
+        paths = sorted(glob(input_path + "/*"))
         if not len(paths):
             raise ValueError(
                 "No images found in save path, aborting (did you pass save_intermediate=True to the generate"
@@ -178,7 +178,7 @@ class VQGAN_CLIP(nn.Module):
             wandb.init(reinit=True, project="face-editor")
             wandb.config.update({"Positive Prompts": positive_prompts})
             wandb.config.update({"Negative Prompts": negative_prompts})
-            wandb.config.update(dict(lr=self.lr, iterations=self.iterations))
+            wandb.config.update({"lr": self.lr, "iterations": self.iterations})
         if image_path:
             image = Image.open(image_path)
             image = image.resize((256, 256))
...
@@ -47,7 +47,7 @@ def get_obj_from_str(string, reload=False):
 def instantiate_from_config(config):
     if "target" not in config:
         raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+    return get_obj_from_str(config["target"])(**config.get("params", {}))


 def load_model_from_config(config, sd, gpu=True, eval_mode=True):
...
@@ -51,7 +51,7 @@ from transformers.trainer_utils import set_seed  # noqa
 set_seed(42)

-models = dict(base="patrickvonplaten/wav2vec2_tiny_random", robust="patrickvonplaten/wav2vec2_tiny_random_robust")
+models = {"base": "patrickvonplaten/wav2vec2_tiny_random", "robust": "patrickvonplaten/wav2vec2_tiny_random_robust"}

 ZERO2 = "zero2"
 ZERO3 = "zero3"
...
@@ -400,7 +400,7 @@ def create_vocabulary_from_data(
         | (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set())
     )

-    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
+    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}

     # replace white space with delimiter token
     if word_delimiter_token is not None:
...
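The `vocab_dict` line is rule C414: `sorted()` accepts any iterable, so materializing the set into a list first was a wasted copy. Sorting before enumerating keeps the vocabulary indices deterministic across runs:

vocab_set = {"a", "c", "b"}
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}
assert vocab_dict == {"a": 0, "b": 1, "c": 2}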
@@ -83,7 +83,7 @@ def can_convert_to_float(string):
 class Plot:
     def __init__(self, args):
         self.args = args
-        self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))
+        self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}})

         with open(self.args.csv_file, newline="") as csv_file:
             reader = csv.DictReader(csv_file)
@@ -116,8 +116,8 @@ class Plot:
         axis.set_major_formatter(ScalarFormatter())

         for model_name_idx, model_name in enumerate(self.result_dict.keys()):
-            batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
-            sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
+            batch_sizes = sorted(set(self.result_dict[model_name]["bsz"]))
+            sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"]))
             results = self.result_dict[model_name]["result"]

             (x_axis_array, inner_loop_array) = (
...
@@ -300,7 +300,7 @@ def main():
     # Prepare label mappings.
     # We'll include these in the model's config to get human readable labels in the Inference API.
     labels = dataset["train"].features["labels"].names
-    label2id, id2label = dict(), dict()
+    label2id, id2label = {}, {}
     for i, label in enumerate(labels):
         label2id[label] = str(i)
         id2label[str(i)] = label
...
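An aside, not part of this commit: the loop that fills the two label mappings could itself be collapsed into dict comprehensions, one per direction:

labels = ["negative", "positive"]  # illustrative label names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
assert label2id == {"negative": "0", "positive": "1"}
assert id2label == {"0": "negative", "1": "positive"}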
@@ -600,7 +600,7 @@ def main():
         if training_args.output_dir is not None:
             output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
-            results_dict = dict()
+            results_dict = {}
             results_dict["train_loss"] = train_loss
             results_dict["train_perplexity"] = train_perplexity
             results_dict["eval_loss"] = validation_loss
...
@@ -623,7 +623,7 @@ def main():
         if training_args.output_dir is not None:
             output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
-            results_dict = dict()
+            results_dict = {}
             results_dict["train_loss"] = train_loss
             results_dict["train_perplexity"] = train_perplexity
             results_dict["eval_loss"] = validation_loss
...
@@ -464,7 +464,7 @@ def main():
         return tokenized_examples

-    processed_datasets = dict()
+    processed_datasets = {}
     if training_args.do_train:
         if "train" not in datasets:
             raise ValueError("--do_train requires a train dataset")
...
@@ -310,12 +310,12 @@ def main():
     if config.label2id != PretrainedConfig(num_labels=num_labels).label2id and not is_regression:
         # Some have all caps in their config, some don't.
         label_name_to_id = {k.lower(): v for k, v in config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+        if sorted(label_name_to_id.keys()) == sorted(label_list):
             label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                 "\nIgnoring the model labels as a result.",
             )
             label_to_id = {label: i for i, label in enumerate(label_list)}
@@ -383,7 +383,7 @@ def main():
         dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
         num_replicas = training_args.strategy.num_replicas_in_sync

-        tf_data = dict()
+        tf_data = {}
         max_samples = {
             "train": data_args.max_train_samples,
             "validation": data_args.max_eval_samples,
...
@@ -343,13 +343,13 @@ def main():
     if "train" in datasets:
         if not is_regression and config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
             label_name_to_id = config.label2id
-            if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+            if sorted(label_name_to_id.keys()) == sorted(label_list):
                 label_to_id = label_name_to_id  # Use the model's labels
             else:
                 logger.warning(
                     "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                    f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels:"
-                    f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
+                    f"model labels: {sorted(label_name_to_id.keys())}, dataset labels:"
+                    f" {sorted(label_list)}.\nIgnoring the model labels as a result.",
                 )
                 label_to_id = {v: i for i, v in enumerate(label_list)}
         elif not is_regression:
@@ -411,7 +411,7 @@ def main():
         dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
         num_replicas = training_args.strategy.num_replicas_in_sync

-        tf_data = dict()
+        tf_data = {}
         max_samples = {
             "train": data_args.max_train_samples,
             "validation": data_args.max_val_samples,
...
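The label-matching comparisons in the last two files are C413 again: since `sorted()` already returns a list, `list(sorted(x)) == list(sorted(y))` and `sorted(x) == sorted(y)` compare exactly the same pair of lists:

label_name_to_id = {"b": 1, "a": 0}
label_list = ["a", "b"]
assert sorted(label_name_to_id.keys()) == sorted(label_list)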