"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "364920e216c16d73c782a61a4cf6652e541fbe18"
Unverified commit 5e8c8eb5, authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
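For readers unfamiliar with the lint rules involved: flake8-comprehensions (ruff's `C4` rule family) flags calls like `set([...])`, `dict(...)` over a generator of key/value pairs, and `list(sorted(...))` that have a more direct comprehension or builtin spelling. Below is a minimal, self-contained sketch of the equivalences behind the rewrites in this commit; the sample data is illustrative, not taken from the diff.

```python
# Illustrative data; each assert shows the old and new spellings are equivalent.
paths = [("layer", "scale"), ("ln", "bias")]

# C403: set() wrapping a list comprehension -> set comprehension
assert set([p[-1] for p in paths]) == {p[-1] for p in paths}

# C402: dict() over a generator of key/value pairs -> dict comprehension
assert dict((k, len(k)) for k, _ in paths) == {k: len(k) for k, _ in paths}

# C413: list() around sorted() is redundant; sorted() already returns a list
names = ["ln", "layer"]
assert list(sorted(names)) == sorted(names)

# C400: list() over a generator expression -> list comprehension
assert list(n.upper() for n in names) == [n.upper() for n in names]
```

Each rewrite is purely syntactic: behavior is preserved while an intermediate list or a redundant builtin call is avoided.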
@@ -892,14 +892,12 @@ def main():
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
...
@@ -756,14 +756,12 @@ def main():
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
...
@@ -648,14 +648,12 @@ def main():
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
...
@@ -679,14 +679,12 @@ def main():
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
...
@@ -791,14 +791,12 @@ def main():
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
...
@@ -333,14 +333,12 @@ def create_train_state(
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
@@ -642,7 +640,7 @@ def main():
         return tokenized_examples

-    processed_raw_datasets = dict()
+    processed_raw_datasets = {}
     if training_args.do_train:
         if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
...
@@ -742,14 +742,12 @@ def main():
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
...
@@ -229,14 +229,12 @@ def create_train_state(
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
@@ -449,7 +447,7 @@ def main():
     ):
         # Some have all caps in their config, some don't.
         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+        if sorted(label_name_to_id.keys()) == sorted(label_list):
             logger.info(
                 f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                 "Using it!"
@@ -458,7 +456,7 @@ def main():
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                 "\nIgnoring the model labels as a result.",
             )
     elif data_args.task_name is None:
...
@@ -290,14 +290,12 @@ def create_train_state(
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
-    layer_norm_named_params = set(
-        [
-            layer[-2:]
-            for layer_norm_name in layer_norm_candidates
-            for layer in flat_params.keys()
-            if layer_norm_name in "".join(layer).lower()
-        ]
-    )
+    layer_norm_named_params = {
+        layer[-2:]
+        for layer_norm_name in layer_norm_candidates
+        for layer in flat_params.keys()
+        if layer_norm_name in "".join(layer).lower()
+    }
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
...
@@ -192,7 +192,7 @@ def main():
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
+        checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)
...
@@ -211,6 +211,6 @@ if __name__ == "__main__":
     # pl use this default format to create a checkpoint:
     # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
     # /pytorch_lightning/callbacks/model_checkpoint.py#L322
-    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
+    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True))
     model = model.load_from_checkpoint(checkpoints[-1])
     trainer.test(model)
@@ -810,10 +810,10 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c)
                 for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
     else:
         logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
@@ -830,7 +830,7 @@ def main():
             # Evaluate
             result = evaluate(args, model, tokenizer, prefix=global_step)
-            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
             results.update(result)
     logger.info("Results: {}".format(results))
...
@@ -189,7 +189,7 @@ def main():
             return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
         elif isinstance(obj, int):
             return obj
-        return list(tokenize_and_encode(o) for o in obj)
+        return [tokenize_and_encode(o) for o in obj]

     logger.info("Encoding dataset...")
     train_dataset = load_rocstories_dataset(args.train_dataset)
...
@@ -696,9 +696,9 @@ def main():
         checkpoints = [args.model_name_or_path]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
@@ -712,7 +712,7 @@ def main():
             # Evaluate
             result = evaluate(args, model, tokenizer, prefix=global_step)
-            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
             results.update(result)
     logger.info("Results: {}".format(results))
...
@@ -111,7 +111,7 @@ def eval_data_dir(
     if num_return_sequences > 1:
         preds = chunks(preds, num_return_sequences)  # batch size chunks, each of size num_return_seq
     for i, pred in enumerate(preds):
-        results.append(dict(pred=pred, id=ids[i].item()))
+        results.append({"pred": pred, "id": ids[i].item()})
     save_json(results, save_path)
     return results, sampler.num_replicas
@@ -232,7 +232,7 @@ def combine_partial_results(partial_results) -> List:
     records = []
     for partial_result in partial_results:
         records.extend(partial_result)
-    records = list(sorted(records, key=lambda x: x["id"]))
+    records = sorted(records, key=lambda x: x["id"])
     preds = [x["pred"] for x in records]
     return preds
...
@@ -76,7 +76,7 @@ def generate_summaries_or_translations(
     fout.close()
     runtime = int(time.time() - start_time)  # seconds
     n_obs = len(examples)
-    return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
+    return {"n_obs": n_obs, "runtime": runtime, "seconds_per_sample": round(runtime / n_obs, 4)}


 def datetime_now():
...
@@ -36,7 +36,7 @@ def parse_search_arg(search):
     groups = search.split()
     entries = {k: vs for k, vs in (g.split("=") for g in groups)}
     entry_names = list(entries.keys())
-    sets = [list(f"--{k} {v}" for v in vs.split(":")) for k, vs in entries.items()]
+    sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
     matrix = [list(x) for x in itertools.product(*sets)]
     return matrix, entry_names
...
@@ -456,7 +456,7 @@ def pickle_save(obj, path):
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))


 def save_git_info(folder_path: str) -> None:
...
@@ -293,7 +293,7 @@ def main():
                 audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
             )
             output_batch["input_values"].append(wav)
-        output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
+        output_batch["labels"] = list(batch[data_args.label_column_name])

         return output_batch
@@ -303,14 +303,14 @@ def main():
         for audio in batch[data_args.audio_column_name]:
             wav = audio["array"]
             output_batch["input_values"].append(wav)
-        output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
+        output_batch["labels"] = list(batch[data_args.label_column_name])

         return output_batch

     # Prepare label mappings.
     # We'll include these in the model's config to get human readable labels in the Inference API.
     labels = raw_datasets["train"].features[data_args.label_column_name].names
-    label2id, id2label = dict(), dict()
+    label2id, id2label = {}, {}
     for i, label in enumerate(labels):
         label2id[label] = str(i)
         id2label[str(i)] = label
...
@@ -83,7 +83,7 @@ def can_convert_to_float(string):
 class Plot:
     def __init__(self, args):
         self.args = args
-        self.result_dict = defaultdict(lambda: dict(bsz=[], seq_len=[], result={}))
+        self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}})

         with open(self.args.csv_file, newline="") as csv_file:
             reader = csv.DictReader(csv_file)
@@ -116,8 +116,8 @@ class Plot:
             axis.set_major_formatter(ScalarFormatter())

         for model_name_idx, model_name in enumerate(self.result_dict.keys()):
-            batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
-            sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
+            batch_sizes = sorted(set(self.result_dict[model_name]["bsz"]))
+            sequence_lengths = sorted(set(self.result_dict[model_name]["seq_len"]))
             results = self.result_dict[model_name]["result"]
             (x_axis_array, inner_loop_array) = (
...
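Note that rewrites like these are mechanical, so a change of this kind can typically be regenerated rather than hand-edited: enabling ruff's flake8-comprehensions rules (the `C4` prefix) and running something like `ruff check --select C4 --fix .` applies the same fixes automatically, though the exact invocation and rule selection depend on the ruff version and the project's configuration.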