Unverified commit 5e8c8eb5 authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
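
This commit mechanically applies the flake8-comprehensions lint rules (ruff's C4 family) across the example and research scripts: `list()`, `dict()`, `set()`, and `tuple()` calls wrapped around generators, lambdas passed to `map()`, and literal lists are replaced with the equivalent comprehension or literal syntax. All rewrites are behavior-preserving. As a minimal sketch of the patterns involved (the names below are hypothetical, not taken from the diff; each before/after pair is equivalent):

    # Illustrative before/after pairs for the rewrite patterns in this commit.
    items = ["b", "a", "c"]
    pairs = [("x", 1), ("y", 2)]

    squares = list(x * x for x in range(5))        # before: list() around a generator
    squares = [x * x for x in range(5)]            # after:  list comprehension

    mapping = dict((k, v) for k, v in pairs)       # before: dict() around a generator
    mapping = {k: v for k, v in pairs}             # after:  dict comprehension

    unique = set([s.upper() for s in items])       # before: set() around a list comprehension
    unique = {s.upper() for s in items}            # after:  set comprehension

    shape = (2, 3) + tuple([4])                    # before: tuple() around a one-element list
    shape = (2, 3) + (4,)                          # after:  tuple literal

    copies = [s for s in items]                    # before: identity comprehension
    copies = list(items)                           # after:  plain constructor call

    lengths = list(map(lambda s: len(s), items))   # before: map() over a lambda
    lengths = [len(s) for s in items]              # after:  comprehension

    ordered = sorted(list(set(items)))             # before: redundant list() inside sorted()
    ordered = sorted(set(items))                   # after:  sorted() already returns a list

Rewrites of this shape can be reproduced with ruff's autofixer, for example `ruff check --select C4 --fix .` (exact flags depend on the installed ruff version).
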
@@ -685,9 +685,9 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
@@ -725,7 +725,7 @@ def main():
                 for i in range(model.num_layers):
                     info_str += " {:.2f}".format(100 * each_layer_results[i])
                 logger.info(info_str)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
            results.update(result)
    return results

@@ -27,7 +27,7 @@ from utils import logger
 def _quantize(x, bins):
     bins = copy.deepcopy(bins)
     bins = sorted(bins)
-    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
+    quantized = [bisect.bisect_right(bins, y) for y in x]
     return quantized

@@ -850,9 +850,9 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
@@ -865,7 +865,7 @@ def main():
             # Evaluate
             result = evaluate(args, model, tokenizer, prefix=global_step)
-            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+            result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
             results.update(result)
     logger.info("Results: {}".format(results))

@@ -247,9 +247,12 @@ class Trainer:
                 lr = self.scheduler_fn(state_step - 1)
                 eval_loss = self.evaluate(state, val_dataset)
-                logging_dict = dict(
-                    step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
-                )
+                logging_dict = {
+                    "step": state_step.item(),
+                    "eval_loss": eval_loss.item(),
+                    "tr_loss": tr_loss,
+                    "lr": lr.item(),
+                }
                 tqdm.write(str(logging_dict))
                 self.logger.log(logging_dict, commit=True)

@@ -144,9 +144,9 @@ def main():
     predictions = expand_to_aliases(example["output"])
     # some preprocessing to both prediction and answer
-    answers = set(["".join(a.split()) for a in answers])
-    predictions = set(["".join(p.split()) for p in predictions])
-    predictions = set([s for s in predictions if s not in ["``", "''", "`", "'"]])
+    answers = {"".join(a.split()) for a in answers}
+    predictions = {"".join(p.split()) for p in predictions}
+    predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}
     # if there is a common element, it's a exact match
     example["match"] = len(list(answers & predictions)) > 0

@@ -314,12 +314,12 @@ if __name__ == "__main__":
     data = data["train" if PROCESS_TRAIN == "true" else "validation"]
-    fn_kwargs = dict(
-        tokenizer=tokenizer,
-        doc_stride=DOC_STRIDE,
-        max_length=MAX_LENGTH,
-        assertion=False,
-    )
+    fn_kwargs = {
+        "tokenizer": tokenizer,
+        "doc_stride": DOC_STRIDE,
+        "max_length": MAX_LENGTH,
+        "assertion": False,
+    }
     data = data.map(prepare_inputs, fn_kwargs=fn_kwargs)
     data = data.remove_columns(["annotations", "document", "id", "question"])
     print(data)

@@ -34,7 +34,7 @@ empty_dict = object()
 def _match(qs, ks):
     """Return True if regexes in qs match any window of strings in tuple ks."""
     # compile regexes and force complete match
-    qts = tuple(map(lambda x: re.compile(x + "$"), qs))
+    qts = tuple((re.compile(x + "$") for x in qs))
     for i in range(len(ks) - len(qs) + 1):
         matches = [x.match(y) for x, y in zip(qts, ks[i:])]
         if matches and all(matches):

@@ -78,7 +78,7 @@ def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_1
     )
     hits = response["hits"]["hits"]
     support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
-    res_list = [dict([(k, hit["_source"][k]) for k in hit["_source"] if k != "passage_text"]) for hit in hits]
+    res_list = [{k: hit["_source"][k] for k in hit["_source"] if k != "passage_text"} for hit in hits]
     for r, hit in zip(res_list, hits):
         r["passage_id"] = hit["_id"]
         r["score"] = hit["_score"]
@@ -601,7 +601,7 @@ def make_qa_dense_index(
     fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
     n_batches = math.ceil(passages_dset.num_rows / batch_size)
     for i in range(n_batches):
-        passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
+        passages = list(passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"])
         reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
         fp[i * batch_size : (i + 1) * batch_size] = reps
         if i % 50 == 0:
@@ -634,7 +634,7 @@ def query_qa_dense_index(
     D, I = wiki_index.search(q_rep, 2 * n_results)
     res_passages = [wiki_passages[int(i)] for i in I[0]]
     support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
-    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
     res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
     for r, sc in zip(res_list, D[0]):
         r["score"] = float(sc)
@@ -650,7 +650,7 @@ def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages,
     ]
     all_res_lists = []
     for res_passages, dl in zip(res_passages_lst, D):
-        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
         for r, sc in zip(res_list, dl):
             r["score"] = float(sc)
         all_res_lists += [res_list[:]]
@@ -663,7 +663,7 @@ def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki
     D, I = wiki_index.search(a_rep, 2 * n_results)
     res_passages = [wiki_passages[int(i)] for i in I[0]]
     support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
-    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
     res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
     for r, sc, i in zip(res_list, D[0], I[0]):
         r["passage_id"] = int(i)
@@ -680,7 +680,7 @@ def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passage
     ]
     all_res_lists = []
     for res_passages, dl, il in zip(res_passages_lst, D, I):
-        res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
+        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
         for r, sc, i in zip(res_list, dl, il):
             r["passage_id"] = int(i)
             r["score"] = float(sc)

@@ -61,7 +61,7 @@ class Extract:
         assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
         if subset_list is not None:
             with open(os.path.realpath(subset_list)) as f:
-                self.subset_list = set(map(lambda x: self._vqa_file_split()[0], tryload(f)))
+                self.subset_list = {self._vqa_file_split()[0] for x in tryload(f)}
         else:
             self.subset_list = None

@@ -1095,7 +1095,7 @@ class ROIPooler(nn.Module):
         Returns:
             A tensor of shape(N*B, Channels, output_size, output_size)
         """
-        x = [v for v in feature_maps.values()]
+        x = list(feature_maps.values())
         num_level_assignments = len(self.level_poolers)
         assert len(x) == num_level_assignments and len(boxes) == x[0].size(0)

@@ -554,9 +554,9 @@ def main():
     if args.do_eval and args.local_rank in [-1, 0]:
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
@@ -566,7 +566,7 @@ def main():
             model.load_state_dict(torch.load(checkpoint))
             model.to(args.device)
             result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
-            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+            result = {k + "_{}".format(global_step): v for k, v in result.items()}
             results.update(result)
     return results

@@ -941,9 +941,9 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
@@ -953,7 +953,7 @@ def main():
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
-           result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
+           result = {k + "_{}".format(global_step): v for k, v in result.items()}
            results.update(result)
    return results

@@ -1109,10 +1109,10 @@ def main():
         logger.info("Loading checkpoints saved during training for evaluation")
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(
+            checkpoints = [
                 os.path.dirname(c)
                 for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
-            )
+            ]
     else:
         logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
@@ -1129,7 +1129,7 @@ def main():
         # Evaluate
         result = evaluate(args, model, tokenizer, prefix=global_step)
-        result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
+        result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
         results.update(result)
     logger.info("Results: {}".format(results))

@@ -42,8 +42,8 @@ def _graph_replace_input_with(graph_proto, name, new_name):
 def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
-    inits_with_data = [i for i in model.graph.initializer]
-    inits = [i for i in model_without_ext.graph.initializer]
+    inits_with_data = list(model.graph.initializer)
+    inits = list(model_without_ext.graph.initializer)
     for i, ref_i in ind_to_replace:
         assert inits_with_data[i].name == inits[i].name
         assert inits_with_data[ref_i].name == inits[ref_i].name
@@ -69,7 +69,7 @@ def remove_dup_initializers(onnx_file_path):
     model = onnx.load(os.path.join(model_file_folder, model_file_name))
-    inits = [i for i in model.graph.initializer]
+    inits = list(model.graph.initializer)
     dup_set = set()
     dup_map = {}

@@ -127,11 +127,9 @@ def perturb_past(
     _, _, _, curr_length, _ = past[0].shape
     if curr_length > window_length and window_length > 0:
-        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
-        zeros_key_val_shape = (
-            tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
-        )
+        ones_key_val_shape = tuple(past[0].shape[:-2]) + (window_length,) + tuple(past[0].shape[-1:])
+        zeros_key_val_shape = tuple(past[0].shape[:-2]) + (curr_length - window_length,) + tuple(past[0].shape[-1:])
         ones_mask = torch.ones(ones_key_val_shape)
         ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)

@@ -164,11 +164,11 @@ class GenerativeQAModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)
-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,

@@ -137,7 +137,7 @@ logger = getLogger(__name__)
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))
 def save_git_info(folder_path: str) -> None:

@@ -162,11 +162,11 @@ class GenerativeQAModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)
-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=prefix or "",
-        )
+        self.dataset_kwargs: dict = {
+            "data_dir": self.hparams.data_dir,
+            "max_source_length": self.hparams.max_source_length,
+            "prefix": prefix or "",
+        }
         n_observations_per_split = {
             "train": self.hparams.n_train,
             "val": self.hparams.n_val,

@@ -137,7 +137,7 @@ logger = getLogger(__name__)
 def flatten_list(summary_ids: List[List]):
-    return [x for x in itertools.chain.from_iterable(summary_ids)]
+    return list(itertools.chain.from_iterable(summary_ids))
 def save_git_info(folder_path: str) -> None:

@@ -344,7 +344,7 @@ def create_vocabulary_from_data(
         lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
     )
-    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
+    vocab_dict = {v: k for k, v in enumerate(sorted(vocab_set))}
     # replace white space with delimiter token
     if word_delimiter_token is not None: