Unverified commit 5e8c8eb5 authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
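Reviewer note: ruff's flake8-comprehensions rules (the C4xx family) flag redundant comprehension and literal-call patterns, and every hunk below is one of these mechanical, behavior-preserving rewrites. A minimal illustrative sketch of the main patterns touched in this commit (rule codes as documented by flake8-comprehensions; the examples themselves are made up):

```python
# Each pair is behavior-preserving; the rewritten form skips an
# intermediate object or a needless global-name call.
set([x * x for x in range(3)])   # C403 -> {x * x for x in range(3)}
dict((k, 1) for k in "ab")       # C402 -> {k: 1 for k in "ab"}
list(sorted([3, 1, 2]))          # C413 -> sorted([3, 1, 2])
dict()                           # C408 -> {}
[x for x in range(3)]            # C416 -> list(range(3))
```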
......
@@ -892,14 +892,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
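The hunk above (repeated near-verbatim across the Flax example scripts) is rule C403: `set([...])` first materializes a throwaway list, then copies it into a set, while a set comprehension builds the set directly. A self-contained sketch with hypothetical parameter paths standing in for the tuple keys `traverse_util.flatten_dict` would produce:

```python
# Hypothetical flattened param tree: keys are tuples of path segments.
flat_params = {
    ("encoder", "layer_norm", "scale"): 1.0,
    ("encoder", "dense", "kernel"): 2.0,
}
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]

# Same logic as the rewritten hunk: collect the trailing (module, leaf)
# pairs of every parameter whose joined path mentions a LayerNorm name.
layer_norm_named_params = {
    layer[-2:]
    for layer_norm_name in layer_norm_candidates
    for layer in flat_params.keys()
    if layer_norm_name in "".join(layer).lower()
}
print(layer_norm_named_params)  # {('layer_norm', 'scale')}
```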
......
......
@@ -756,14 +756,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
......
......
@@ -648,14 +648,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
......
......
@@ -679,14 +679,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
......
......
@@ -791,14 +791,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
......
......
@@ -333,14 +333,12 @@ def create_train_state(
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
......
@@ -642,7 +640,7 @@ def main():
         return tokenized_examples
 
-    processed_raw_datasets = dict()
+    processed_raw_datasets = {}
     if training_args.do_train:
         if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
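This one is C408: `dict()` resolves a global name and makes a function call, whereas the `{}` literal compiles to a single bytecode instruction. A quick, informal micro-benchmark (absolute numbers vary by machine):

```python
import timeit

# The literal typically wins by roughly 2x because it avoids the
# global-name lookup and call overhead of dict().
print(timeit.timeit("dict()", number=1_000_000))
print(timeit.timeit("{}", number=1_000_000))
```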
......
......
@@ -742,14 +742,12 @@ def main():
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
......
......
@@ -229,14 +229,12 @@ def create_train_state(
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
......
@@ -449,7 +447,7 @@ def main():
     ):
         # Some have all caps in their config, some don't.
         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+        if sorted(label_name_to_id.keys()) == sorted(label_list):
             logger.info(
                 f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                 "Using it!"
......
@@ -458,7 +456,7 @@ def main():
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                 "\nIgnoring the model labels as a result.",
             )
     elif data_args.task_name is None:
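Both hunks above are C413: `sorted()` already returns a new list, so wrapping it in `list()` just copies the result. Equality and order are unchanged, as this small check with made-up labels shows:

```python
# sorted() returns a list, so list(sorted(x)) is a redundant copy.
label_list = ["negative", "positive"]
label_name_to_id = {"positive": 1, "negative": 0}

assert sorted(label_name_to_id.keys()) == list(sorted(label_name_to_id.keys()))
assert sorted(label_name_to_id.keys()) == sorted(label_list)
```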
......
......
@@ -290,14 +290,12 @@ def create_train_state(
         flat_params = traverse_util.flatten_dict(params)
         # find out all LayerNorm parameters
         layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-        layer_norm_named_params = set(
-            [
-                layer[-2:]
-                for layer_norm_name in layer_norm_candidates
-                for layer in flat_params.keys()
-                if layer_norm_name in "".join(layer).lower()
-            ]
-        )
+        layer_norm_named_params = {
+            layer[-2:]
+            for layer_norm_name in layer_norm_candidates
+            for layer in flat_params.keys()
+            if layer_norm_name in "".join(layer).lower()
+        }
         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)
......
......
@@ -192,7 +192,7 @@ def main():
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
+        checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)
......
......
@@ -211,6 +211,6 @@ if __name__ == "__main__":
         # pl use this default format to create a checkpoint:
         # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
         # /pytorch_lightning/callbacks/model_checkpoint.py#L322
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
+        checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
         model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
......
@@ -810,10 +810,10 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c)
                 for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
     else:
         logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
......
@@ -830,7 +830,7 @@ def main():
         # Evaluate
         result = evaluate(args, model, tokenizer, prefix=global_step)
-        result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+        result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
         results.update(result)
     logger.info("Results: {}".format(results))
......
......
@@ -189,7 +189,7 @@ def main():
             return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
         elif isinstance(obj, int):
             return obj
-        return list(tokenize_and_encode(o) for o in obj)
+        return [tokenize_and_encode(o) for o in obj]
 
     logger.info("Encoding dataset...")
     train_dataset = load_rocstories_dataset(args.train_dataset)
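This is C400: `list(<genexpr>)` adds a needless generator layer; a list comprehension produces the same list with one less hop. A toy version of the recursive encoder, with a stand-in `encode` instead of a real tokenizer:

```python
# Stand-in for tokenizer.convert_tokens_to_ids(tokenizer.tokenize(...)).
def encode(text):
    return [ord(c) for c in text]

def tokenize_and_encode(obj):
    if isinstance(obj, str):
        return encode(obj)
    elif isinstance(obj, int):
        return obj
    # Recurse over nested lists/tuples, as in the rewritten hunk.
    return [tokenize_and_encode(o) for o in obj]

print(tokenize_and_encode(["ab", [1, "c"]]))  # [[97, 98], [1, [99]]]
```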
......
......
@@ -696,9 +696,9 @@ def main():
         checkpoints = [args.model_name_or_path]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
......
@@ -712,7 +712,7 @@ def main():
         # Evaluate
         result = evaluate(args, model, tokenizer, prefix=global_step)
-        result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+        result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
         results.update(result)
     logger.info("Results: {}".format(results))
......
......
@@ -111,7 +111,7 @@ def eval_data_dir(
     if num_return_sequences > 1:
         preds = chunks(preds, num_return_sequences)  # batch size chunks, each of size num_return_seq
     for i, pred in enumerate(preds):
-        results.append(dict(pred=pred, id=ids[i].item()))
+        results.append({"pred": pred, "id": ids[i].item()})
     save_json(results, save_path)
     return results, sampler.num_replicas
......
@@ -232,7 +232,7 @@ def combine_partial_results(partial_results) -> List:
     records = []
     for partial_result in partial_results:
         records.extend(partial_result)
-    records = list(sorted(records, key=lambda x: x["id"]))
+    records = sorted(records, key=lambda x: x["id"])
     preds = [x["pred"] for x in records]
     return preds
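Two rewrites in this file: `dict(pred=..., id=...)` is C408 again (keyword-argument form), and `list(sorted(...))` is C413. The merge step behaves identically either way; a tiny check with fabricated per-rank shard results:

```python
# Fabricated per-rank partial results, out of order across shards.
partial_results = [[{"id": 2, "pred": "b"}], [{"id": 1, "pred": "a"}]]

records = []
for partial_result in partial_results:
    records.extend(partial_result)
records = sorted(records, key=lambda x: x["id"])
preds = [x["pred"] for x in records]
print(preds)  # ['a', 'b']
```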
......
......
@@ -76,7 +76,7 @@ def generate_summaries_or_translations(
     fout.close()
     runtime = int(time.time() - start_time)  # seconds
     n_obs = len(examples)
-    return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
+    return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}
 
 
 def datetime_now():
......
......
@@ -36,7 +36,7 @@ def parse_search_arg(search):
     groups = search.split()
     entries = {k: vs for k, vs in (g.split("=") for g in groups)}
     entry_names = list(entries.keys())
-    sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
+    sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
     matrix = [list(x) for x in itertools.product(*sets)]
     return matrix, entry_names
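The inner `list(<genexpr>)` here is C400; as a nested list comprehension, the grid expansion reads more directly. A hypothetical usage of the rewritten helper body, with a made-up search spec:

```python
import itertools

# "lr=1e-4:1e-5 bs=8:16"-style search spec, pre-split into entries.
entries = {"lr": "1e-4:1e-5", "bs": "8:16"}
sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
matrix = [list(x) for x in itertools.product(*sets)]
print(len(matrix))  # 4 combinations
print(matrix[0])    # ['--lr 1e-4', '--bs 8']
```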
......
......
@@ -456,7 +456,7 @@ def pickle_save(obj, path):
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))
 
 
 def save_git_info(folder_path: str) -> None:
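C416: a comprehension whose body just re-emits its element is an identity copy, so `list(...)` states the intent directly (and avoids the per-item comprehension overhead). For example:

```python
import itertools
from typing import List

def flatten_list(summary_ids: List[List]) -> List:
    # list() over the chained iterator replaces the identity comprehension.
    return list(itertools.chain.from_iterable(summary_ids))

print(flatten_list([[1, 2], [3], []]))  # [1, 2, 3]
```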
......
......
@@ -293,7 +293,7 @@ def main():
                 audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
             )
             output_batch["input_values"].append(wav)
-        output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
+        output_batch["labels"] = list(batch[data_args.label_column_name])
         return output_batch
......
@@ -303,14 +303,14 @@ def main():
         for audio in batch[data_args.audio_column_name]:
             wav = audio["array"]
             output_batch["input_values"].append(wav)
-        output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
+        output_batch["labels"] = list(batch[data_args.label_column_name])
 
         return output_batch
 
     # Prepare label mappings.
     # We'll include these in the model's config to get human readable labels in the Inference API.
     labels = raw_datasets["train"].features[data_args.label_column_name].names
-    label2id, id2label = dict(), dict()
+    label2id, id2label = {}, {}
     for i, label in enumerate(labels):
         label2id[label] = str(i)
         id2label[str(i)] = label
......
......
@@ -83,7 +83,7 @@ def can_convert_to_float(string):
 class Plot:
     def __init__(self, args):
         self.args = args
-        self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))
+        self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}})
 
         with open(self.args.csv_file, newline="") as csv_file:
             reader = csv.DictReader(csv_file)
......
@@ -116,8 +116,8 @@ class Plot:
             axis.set_major_formatter(ScalarFormatter())
 
         for model_name_idx, model_name in enumerate(self.result_dict.keys()):
-            batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
-            sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
+            batch_sizes = sorted(set(self.result_dict[model_name]["bsz"]))
+            sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"]))
             results = self.result_dict[model_name]["result"]
             (x_axis_array, inner_loop_array) = (
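The last pattern is C414: `sorted()` accepts any iterable, so the inner `list()` around the set is a wasted copy. A small check with hypothetical benchmark rows:

```python
# Hypothetical per-row batch sizes parsed from the benchmark CSV.
bsz = [8, 1, 8, 4, 1]

# sorted(set(...)) dedupes and orders without the intermediate list.
batch_sizes = sorted(set(bsz))
print(batch_sizes)  # [1, 4, 8]
```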
......